Skip to content

Commit

Permalink
Changes back to Copy-capable ArrayBase and unchanged ArrayBase internals
Browse files Browse the repository at this point in the history
  • Loading branch information
akern40 committed Oct 6, 2024
1 parent 5168fb6 commit 8469495
Show file tree
Hide file tree
Showing 23 changed files with 124 additions and 164 deletions.
16 changes: 8 additions & 8 deletions benches/bench1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ fn bench_row_iter(bench: &mut test::Bencher)
let a = Array::<f32, _>::zeros((1024, 1024));
let it = a.row(17);
bench.iter(|| {
for elt in it.clone() {
for elt in it {
black_box(elt);
}
})
Expand All @@ -788,7 +788,7 @@ fn bench_col_iter(bench: &mut test::Bencher)
let a = Array::<f32, _>::zeros((1024, 1024));
let it = a.column(17);
bench.iter(|| {
for elt in it.clone() {
for elt in it {
black_box(elt);
}
})
Expand Down Expand Up @@ -861,7 +861,7 @@ fn create_iter_4d(bench: &mut test::Bencher)
a.swap_axes(2, 1);
let v = black_box(a.view());

bench.iter(|| v.clone().into_iter());
bench.iter(|| v.into_iter());
}

#[bench]
Expand Down Expand Up @@ -1023,23 +1023,23 @@ fn into_dimensionality_ix1_ok(bench: &mut test::Bencher)
{
let a = Array::<f32, _>::zeros(Ix1(10));
let a = a.view();
bench.iter(|| a.clone().into_dimensionality::<Ix1>());
bench.iter(|| a.into_dimensionality::<Ix1>());
}

#[bench]
fn into_dimensionality_ix3_ok(bench: &mut test::Bencher)
{
let a = Array::<f32, _>::zeros(Ix3(10, 10, 10));
let a = a.view();
bench.iter(|| a.clone().into_dimensionality::<Ix3>());
bench.iter(|| a.into_dimensionality::<Ix3>());
}

#[bench]
fn into_dimensionality_ix3_err(bench: &mut test::Bencher)
{
let a = Array::<f32, _>::zeros(Ix3(10, 10, 10));
let a = a.view();
bench.iter(|| a.clone().into_dimensionality::<Ix2>());
bench.iter(|| a.into_dimensionality::<Ix2>());
}

#[bench]
Expand All @@ -1063,15 +1063,15 @@ fn into_dyn_ix3(bench: &mut test::Bencher)
{
let a = Array::<f32, _>::zeros(Ix3(10, 10, 10));
let a = a.view();
bench.iter(|| a.clone().into_dyn());
bench.iter(|| a.into_dyn());
}

#[bench]
fn into_dyn_ix5(bench: &mut test::Bencher)
{
let a = Array::<f32, _>::zeros(Ix5(2, 2, 2, 2, 2));
let a = a.view();
bench.iter(|| a.clone().into_dyn());
bench.iter(|| a.into_dyn());
}

#[bench]
Expand Down
2 changes: 1 addition & 1 deletion src/arraytraits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ where D: Dimension
{
let data = OwnedArcRepr(Arc::new(arr.data));
// safe because: equivalent unmoved data, ptr and dims remain valid
unsafe { ArrayBase::from_data_ptr(data, arr.aref.ptr).with_strides_dim(arr.aref.strides, arr.aref.dim) }
unsafe { ArrayBase::from_data_ptr(data, arr.ptr).with_strides_dim(arr.strides, arr.dim) }
}
}

Expand Down
23 changes: 11 additions & 12 deletions src/data_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,13 @@ where A: Clone
let rcvec = &mut self_.data.0;
let a_size = mem::size_of::<A>() as isize;
let our_off = if a_size != 0 {
(self_.aref.ptr.as_ptr() as isize - rcvec.as_ptr() as isize) / a_size
(self_.ptr.as_ptr() as isize - rcvec.as_ptr() as isize) / a_size
} else {
0
};
let rvec = Arc::make_mut(rcvec);
unsafe {
self_.aref.ptr = rvec.as_nonnull_mut().offset(our_off);
self_.ptr = rvec.as_nonnull_mut().offset(our_off);
}
}

Expand All @@ -305,7 +305,7 @@ unsafe impl<A> Data for OwnedArcRepr<A>
Self::ensure_unique(&mut self_);
let data = Arc::try_unwrap(self_.data.0).ok().unwrap();
// safe because data is equivalent
unsafe { ArrayBase::from_data_ptr(data, self_.aref.ptr).with_strides_dim(self_.aref.strides, self_.aref.dim) }
unsafe { ArrayBase::from_data_ptr(data, self_.ptr).with_strides_dim(self_.strides, self_.dim) }
}

fn try_into_owned_nocopy<D>(self_: ArrayBase<Self, D>) -> Result<Array<Self::Elem, D>, ArrayBase<Self, D>>
Expand All @@ -314,14 +314,13 @@ unsafe impl<A> Data for OwnedArcRepr<A>
match Arc::try_unwrap(self_.data.0) {
Ok(owned_data) => unsafe {
// Safe because the data is equivalent.
Ok(ArrayBase::from_data_ptr(owned_data, self_.aref.ptr)
.with_strides_dim(self_.aref.strides, self_.aref.dim))
Ok(ArrayBase::from_data_ptr(owned_data, self_.ptr).with_strides_dim(self_.strides, self_.dim))
},
Err(arc_data) => unsafe {
// Safe because the data is equivalent; we're just
// reconstructing `self_`.
Err(ArrayBase::from_data_ptr(OwnedArcRepr(arc_data), self_.aref.ptr)
.with_strides_dim(self_.aref.strides, self_.aref.dim))
Err(ArrayBase::from_data_ptr(OwnedArcRepr(arc_data), self_.ptr)
.with_strides_dim(self_.strides, self_.dim))
},
}
}
Expand Down Expand Up @@ -624,9 +623,9 @@ where A: Clone
CowRepr::View(_) => {
let owned = array.to_owned();
array.data = CowRepr::Owned(owned.data);
array.aref.ptr = owned.aref.ptr;
array.aref.dim = owned.aref.dim;
array.aref.strides = owned.aref.strides;
array.ptr = owned.ptr;
array.dim = owned.dim;
array.strides = owned.strides;
}
CowRepr::Owned(_) => {}
}
Expand Down Expand Up @@ -687,7 +686,7 @@ unsafe impl<'a, A> Data for CowRepr<'a, A>
CowRepr::View(_) => self_.to_owned(),
CowRepr::Owned(data) => unsafe {
// safe because the data is equivalent so ptr, dims remain valid
ArrayBase::from_data_ptr(data, self_.aref.ptr).with_strides_dim(self_.aref.strides, self_.aref.dim)
ArrayBase::from_data_ptr(data, self_.ptr).with_strides_dim(self_.strides, self_.dim)
},
}
}
Expand All @@ -699,7 +698,7 @@ unsafe impl<'a, A> Data for CowRepr<'a, A>
CowRepr::View(_) => Err(self_),
CowRepr::Owned(data) => unsafe {
// safe because the data is equivalent so ptr, dims remain valid
Ok(ArrayBase::from_data_ptr(data, self_.aref.ptr).with_strides_dim(self_.aref.strides, self_.aref.dim))
Ok(ArrayBase::from_data_ptr(data, self_.ptr).with_strides_dim(self_.strides, self_.dim))
},
}
}
Expand Down
32 changes: 11 additions & 21 deletions src/free_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
use alloc::vec;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use core::marker::PhantomData;
#[allow(unused_imports)]
use std::compile_error;
use std::mem::{forget, size_of};
Expand Down Expand Up @@ -107,13 +106,10 @@ pub const fn aview0<A>(x: &A) -> ArrayView0<'_, A>
{
ArrayBase {
data: ViewRepr::new(),
aref: RefBase {
// Safe because references are always non-null.
ptr: unsafe { NonNull::new_unchecked(x as *const A as *mut A) },
dim: Ix0(),
strides: Ix0(),
phantom: PhantomData::<<ViewRepr<&'_ A> as RawData>::Referent>,
},
// Safe because references are always non-null.
ptr: unsafe { NonNull::new_unchecked(x as *const A as *mut A) },
dim: Ix0(),
strides: Ix0(),
}
}

Expand Down Expand Up @@ -148,13 +144,10 @@ pub const fn aview1<A>(xs: &[A]) -> ArrayView1<'_, A>
}
ArrayBase {
data: ViewRepr::new(),
aref: RefBase {
// Safe because references are always non-null.
ptr: unsafe { NonNull::new_unchecked(xs.as_ptr() as *mut A) },
dim: Ix1(xs.len()),
strides: Ix1(1),
phantom: PhantomData::<<ViewRepr<&'_ A> as RawData>::Referent>,
},
// Safe because references are always non-null.
ptr: unsafe { NonNull::new_unchecked(xs.as_ptr() as *mut A) },
dim: Ix1(xs.len()),
strides: Ix1(1),
}
}

Expand Down Expand Up @@ -207,12 +200,9 @@ pub const fn aview2<A, const N: usize>(xs: &[[A; N]]) -> ArrayView2<'_, A>
};
ArrayBase {
data: ViewRepr::new(),
aref: RefBase {
ptr,
dim,
strides,
phantom: PhantomData::<<ViewRepr<&'_ A> as RawData>::Referent>,
},
ptr,
dim,
strides,
}
}

Expand Down
15 changes: 6 additions & 9 deletions src/impl_clone.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,9 @@ impl<S: RawDataClone, D: Clone> Clone for ArrayBase<S, D>
let (data, ptr) = self.data.clone_with_ptr(self.ptr);
ArrayBase {
data,
aref: RefBase {
ptr,
dim: self.dim.clone(),
strides: self.strides.clone(),
phantom: self.phantom,
},
ptr,
dim: self.dim.clone(),
strides: self.strides.clone(),
}
}
}
Expand All @@ -34,9 +31,9 @@ impl<S: RawDataClone, D: Clone> Clone for ArrayBase<S, D>
fn clone_from(&mut self, other: &Self)
{
unsafe {
self.aref.ptr = self.data.clone_from_with_ptr(&other.data, other.ptr);
self.aref.dim.clone_from(&other.dim);
self.aref.strides.clone_from(&other.strides);
self.ptr = self.data.clone_from_with_ptr(&other.data, other.ptr);
self.dim.clone_from(&other.dim);
self.strides.clone_from(&other.strides);
}
}
}
Expand Down
8 changes: 2 additions & 6 deletions src/impl_cow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@ where D: Dimension
fn from(view: ArrayView<'a, A, D>) -> CowArray<'a, A, D>
{
// safe because equivalent data
unsafe {
ArrayBase::from_data_ptr(CowRepr::View(view.data), view.ptr)
.with_strides_dim(view.aref.strides, view.aref.dim)
}
unsafe { ArrayBase::from_data_ptr(CowRepr::View(view.data), view.ptr).with_strides_dim(view.strides, view.dim) }
}
}

Expand All @@ -47,8 +44,7 @@ where D: Dimension
{
// safe because equivalent data
unsafe {
ArrayBase::from_data_ptr(CowRepr::Owned(array.data), array.aref.ptr)
.with_strides_dim(array.aref.strides, array.aref.dim)
ArrayBase::from_data_ptr(CowRepr::Owned(array.data), array.ptr).with_strides_dim(array.strides, array.dim)
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/impl_dyn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ where S: Data<Elem = A>
pub fn insert_axis_inplace(&mut self, axis: Axis)
{
assert!(axis.index() <= self.ndim());
self.aref.dim = self.dim.insert_axis(axis);
self.aref.strides = self.strides.insert_axis(axis);
self.dim = self.dim.insert_axis(axis);
self.strides = self.strides.insert_axis(axis);
}

/// Collapses the array to `index` along the axis and removes the axis,
Expand All @@ -55,8 +55,8 @@ where S: Data<Elem = A>
pub fn index_axis_inplace(&mut self, axis: Axis, index: usize)
{
self.collapse_axis(axis, index);
self.aref.dim = self.dim.remove_axis(axis);
self.aref.strides = self.strides.remove_axis(axis);
self.dim = self.dim.remove_axis(axis);
self.strides = self.strides.remove_axis(axis);
}

/// Remove axes of length 1 and return the modified array.
Expand Down
19 changes: 6 additions & 13 deletions src/impl_internal_constructors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use core::marker::PhantomData;
use std::ptr::NonNull;

use crate::imp_prelude::*;
Expand All @@ -28,12 +27,9 @@ where S: RawData<Elem = A>
{
let array = ArrayBase {
data,
aref: RefBase {
ptr,
dim: Ix1(0),
strides: Ix1(1),
phantom: PhantomData::<S::Referent>,
},
ptr,
dim: Ix1(0),
strides: Ix1(1),
};
debug_assert!(array.pointer_is_inbounds());
array
Expand Down Expand Up @@ -62,12 +58,9 @@ where
debug_assert_eq!(strides.ndim(), dim.ndim());
ArrayBase {
data: self.data,
aref: RefBase {
ptr: self.aref.ptr,
dim,
strides,
phantom: self.aref.phantom,
},
ptr: self.ptr,
dim,
strides,
}
}
}
Loading

0 comments on commit 8469495

Please sign in to comment.