Skip to content

Commit

Permalink
add SoaBuffer<T>
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Oct 7, 2023
1 parent 74e87b4 commit 775485b
Show file tree
Hide file tree
Showing 10 changed files with 536 additions and 84 deletions.
167 changes: 143 additions & 24 deletions luisa_compute/src/lang/soa.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,139 @@
use crate::internal_prelude::*;
use crate::prelude::*;
use luisa_compute_ir::ir::Type;
use crate::runtime::submit_default_stream_and_sync;
use parking_lot::Mutex;
use std::ops::RangeBounds;
use std::sync::Arc;

use crate::runtime::Kernel;

use super::index::IntoIndex;
use super::types::SoaValue;
/** A buffer with SOA layout.
*/
pub struct SoaBuffer<T: SoaValue> {
storage: Arc<ByteBuffer>,
metadata: Buffer<SoaMetadata>,
_marker: std::marker::PhantomData<T>,
pub(crate) device: Device,
pub(crate) storage: Arc<ByteBuffer>,
pub(crate) metadata_buf: Buffer<SoaMetadata>,
pub(crate) metadata: SoaMetadata,
pub(crate) copy_kernel: Mutex<Option<SoaBufferCopyKernel<T>>>,
pub(crate) _marker: std::marker::PhantomData<T>,
}
/// Pair of kernels that move elements between a [`SoaBuffer`] and a plain
/// [`Buffer`]. Both take the SoA buffer, the plain buffer and a `u64` offset.
pub(crate) struct SoaBufferCopyKernel<T: SoaValue> {
    /// Reads elements from the SoA buffer and writes them to the plain buffer.
    copy_to: Kernel<fn(SoaBuffer<T>, Buffer<T>, u64)>,
    /// Reads elements from the plain buffer and writes them to the SoA buffer.
    copy_from: Kernel<fn(SoaBuffer<T>, Buffer<T>, u64)>,
}
impl<T: SoaValue> SoaBufferCopyKernel<T> {
    /// Creates both copy kernels on `device`.
    ///
    /// Each kernel thread handles the element at `dispatch_id().x + offset`
    /// and uses that same index on the SoA side and on the plain-buffer side.
    // NOTE(review): the offset is applied to both buffers, so the plain buffer
    // is indexed in the same coordinates as the SoA side — confirm this is the
    // intended layout for views that do not start at 0.
    #[tracked]
    fn new(device: &Device) -> Self {
        // SoA -> plain buffer.
        let copy_to = device.create_kernel::<fn(SoaBuffer<T>, Buffer<T>, u64)>(
            &|src, dst, start| {
                let idx = dispatch_id().x.as_u64() + start;
                let elem = src.read(idx);
                dst.write(idx, elem);
            },
        );
        // Plain buffer -> SoA.
        let copy_from = device.create_kernel::<fn(SoaBuffer<T>, Buffer<T>, u64)>(
            &|dst, src, start| {
                let idx = dispatch_id().x.as_u64() + start;
                let elem = src.read(idx);
                dst.write(idx, elem);
            },
        );
        Self { copy_to, copy_from }
    }
}
impl<T: SoaValue> SoaBuffer<T> {
    /// Returns a device-side variable covering the whole buffer.
    pub fn var(&self) -> SoaBufferVar<T> {
        self.view(..).var()
    }
    /// Number of elements, taken from the host-side metadata.
    pub fn len(&self) -> usize {
        self.metadata.count as usize
    }
    /// Element count as an expression read from slot 0 of the device-side
    /// metadata buffer.
    pub fn len_expr(&self) -> Expr<u64> {
        self.metadata_buf.read(0).count
    }
    /// Creates a view over the elements selected by `range`.
    ///
    /// A new one-element metadata buffer describing the view is created on the
    /// device for every call.
    ///
    /// # Panics
    ///
    /// Panics if the range is reversed, ends past `len()`, or has a bound
    /// whose conversion to a half-open range would overflow `usize`.
    pub fn view<S: RangeBounds<usize>>(&self, range: S) -> SoaBufferView<T> {
        let len = self.len();
        // Normalise to a half-open `[lower, upper)` range. `checked_add` makes
        // a bound at `usize::MAX` panic instead of wrapping to 0 in release
        // builds, which would otherwise produce a silently wrong view.
        let lower = match range.start_bound() {
            std::ops::Bound::Included(&x) => x,
            std::ops::Bound::Excluded(&x) => x
                .checked_add(1)
                .expect("SoaBuffer::view: start bound overflows usize"),
            std::ops::Bound::Unbounded => 0,
        };
        let upper = match range.end_bound() {
            std::ops::Bound::Included(&x) => x
                .checked_add(1)
                .expect("SoaBuffer::view: end bound overflows usize"),
            std::ops::Bound::Excluded(&x) => x,
            std::ops::Bound::Unbounded => len,
        };
        assert!(
            lower <= upper,
            "SoaBuffer::view: range start {} is past range end {}",
            lower,
            upper
        );
        assert!(
            upper <= len,
            "SoaBuffer::view: range end {} is out of bounds for length {}",
            upper,
            len
        );
        let metadata = SoaMetadata {
            count: self.metadata.count,
            view_start: lower as u64,
            view_count: (upper - lower) as u64,
        };
        SoaBufferView {
            metadata_buf: self.device.create_buffer_from_slice(&[metadata]),
            metadata,
            buffer: self,
        }
    }
    /// Returns a command that copies `buffer` into the whole SoA buffer.
    // NOTE(review): the temporary view, and the metadata buffer it created, is
    // dropped when this function returns — confirm the returned command keeps
    // that buffer alive until it has executed.
    pub fn copy_from_buffer_async(&self, buffer: &Buffer<T>) -> Command<'static, 'static> {
        self.view(..).copy_from_buffer_async(buffer)
    }
    /// Copies `buffer` into the whole SoA buffer via the view's synchronous copy.
    pub fn copy_from_buffer(&self, buffer: &Buffer<T>) {
        self.view(..).copy_from_buffer(buffer)
    }
    /// Returns a command that copies the whole SoA buffer into `buffer`.
    // NOTE(review): same temporary-view lifetime question as in
    // `copy_from_buffer_async`.
    pub fn copy_to_buffer_async(&self, buffer: &Buffer<T>) -> Command<'static, 'static> {
        self.view(..).copy_to_buffer_async(buffer)
    }
    /// Copies the whole SoA buffer into `buffer` via the view's synchronous copy.
    pub fn copy_to_buffer(&self, buffer: &Buffer<T>) {
        self.view(..).copy_to_buffer(buffer)
    }
}
impl<'a, T: SoaValue> SoaBufferView<'a, T> {
    /// Builds the parent buffer's copy kernels on first use; later calls find
    /// them already present and do nothing.
    fn init_copy_kernel(&self) {
        let mut slot = self.buffer.copy_kernel.lock();
        slot.get_or_insert_with(|| SoaBufferCopyKernel::new(&self.buffer.device));
    }
    /// Dispatch size for a copy: one thread per element in the view.
    ///
    /// # Panics
    ///
    /// Panics if the view has more than `u32::MAX` elements. A plain `as u32`
    /// cast would truncate the count and copy only part of the view.
    fn copy_dispatch_size(&self) -> [u32; 3] {
        let count = self.metadata.view_count;
        assert!(
            count <= u64::from(u32::MAX),
            "SoaBufferView: {} elements exceed the u32 dispatch size used for copies",
            count
        );
        [count as u32, 1, 1]
    }
    /// Returns a device-side variable for this view, built from the parent's
    /// storage and this view's device-side metadata, starting at buffer offset 0.
    pub fn var(&self) -> SoaBufferVar<T> {
        SoaBufferVar {
            proxy: T::SoaBuffer::from_soa_storage(
                self.buffer.storage.var(),
                self.metadata_buf.read(0),
                0,
            ),
        }
    }
    /// Returns a command that copies elements from `buffer` into this view.
    ///
    /// One thread is dispatched per view element and `view_start` is passed to
    /// the kernel as the index offset.
    ///
    /// # Panics
    ///
    /// Panics if the view has more than `u32::MAX` elements.
    // NOTE(review): nothing here checks that `buffer` is large enough for the
    // indices the kernel touches — confirm callers guarantee it.
    pub fn copy_from_buffer_async(&self, buffer: &Buffer<T>) -> Command<'static, 'static> {
        let dispatch_size = self.copy_dispatch_size();
        self.init_copy_kernel();
        let copy_kernel = self.buffer.copy_kernel.lock();
        let copy_kernel = copy_kernel
            .as_ref()
            .expect("copy kernels are created by init_copy_kernel");
        copy_kernel.copy_from.dispatch_async(
            dispatch_size,
            self,
            buffer,
            &self.metadata.view_start,
        )
    }
    /// Copies from `buffer` into this view by submitting the async command to
    /// the device's default stream and synchronising.
    pub fn copy_from_buffer(&self, buffer: &Buffer<T>) {
        submit_default_stream_and_sync(&self.buffer.device, [self.copy_from_buffer_async(buffer)]);
    }
    /// Returns a command that copies elements from this view into `buffer`.
    ///
    /// One thread is dispatched per view element and `view_start` is passed to
    /// the kernel as the index offset.
    ///
    /// # Panics
    ///
    /// Panics if the view has more than `u32::MAX` elements.
    // NOTE(review): nothing here checks that `buffer` is large enough for the
    // indices the kernel touches — confirm callers guarantee it.
    pub fn copy_to_buffer_async(&self, buffer: &Buffer<T>) -> Command<'static, 'static> {
        let dispatch_size = self.copy_dispatch_size();
        self.init_copy_kernel();
        let copy_kernel = self.buffer.copy_kernel.lock();
        let copy_kernel = copy_kernel
            .as_ref()
            .expect("copy kernels are created by init_copy_kernel");
        copy_kernel.copy_to.dispatch_async(
            dispatch_size,
            self,
            buffer,
            &self.metadata.view_start,
        )
    }
    /// Copies from this view into `buffer` by submitting the async command to
    /// the device's default stream and synchronising.
    pub fn copy_to_buffer(&self, buffer: &Buffer<T>) {
        submit_default_stream_and_sync(&self.buffer.device, [self.copy_to_buffer_async(buffer)]);
    }
}
#[derive(Clone, Copy, Value)]
#[repr(C)]
Expand All @@ -19,28 +144,22 @@ pub struct SoaMetadata {
pub view_start: u64,
pub view_count: u64,
}
/// Wrapper around the shared byte buffer that backs SoA data.
pub(crate) struct SoaStorage {
    /// Reference-counted handle to the underlying byte buffer.
    data: Arc<ByteBuffer>,
}
pub struct SoaBufferView<'a, T: SoaValue> {
metadata: Buffer<SoaMetadata>,
buffer: &'a SoaBuffer<T>,
pub(crate) metadata_buf: Buffer<SoaMetadata>,
pub(crate) metadata: SoaMetadata,
pub(crate) buffer: &'a SoaBuffer<T>,
}
pub struct SoaBufferVar<T: SoaValue> {
proxy: T::SoaBuffer,
pub(crate) proxy: T::SoaBuffer,
}

fn compute_number_of_32bits_buffers(ty: &Type) -> usize {
(ty.size() + 3) / 4
impl<T: SoaValue> IndexRead for SoaBufferVar<T> {
type Element = T;
fn read<I: IntoIndex>(&self, i: I) -> Expr<Self::Element> {
self.proxy.read(i)
}
}
/// Element writes on a SoA buffer variable are forwarded to its proxy.
impl<T: SoaValue> IndexWrite for SoaBufferVar<T> {
    /// Stores `value` at index `i` through the underlying proxy.
    fn write<I: IntoIndex, V: AsExpr<Value = Self::Element>>(&self, i: I, value: V) {
        self.proxy.write(i, value)
    }
}

// impl<T> IndexRead for SoaBuffer<T>
// where
// T: Value,
// {
// type Element = T;
// fn read<I: super::index::IntoIndex>(&self, i: I) -> Expr<Self::Element> {
// let i = i.to_u64();
// todo!()
// }
// }
50 changes: 50 additions & 0 deletions luisa_compute/src/lang/types/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,56 @@ impl<T: Value, const N: usize> ArrayNewExpr<T, N> for [T; N] {
impl_simple_expr_proxy!([T: Value, const N: usize] ArrayExpr[T, N] for [T; N]);
impl_simple_var_proxy!([T: Value, const N: usize] ArrayVar[T, N] for [T; N]);
impl_simple_atomic_ref_proxy!([T: Value, const N: usize] ArrayAtomicRef[T, N] for [T; N]);
/// SoA proxy for `[T; N]`: holds one `T::SoaBuffer` proxy per array slot.
#[derive(Clone)]
pub struct ArraySoa<T: SoaValue, const N: usize> {
    /// Per-slot proxies; `from_soa_storage` fills this with `N` entries.
    pub(crate) elems: Vec<T::SoaBuffer>,
    /// Records the array type without storing an array.
    _marker: PhantomData<[T; N]>,
}
/// Arrays of SoA-capable elements are themselves SoA-capable, with
/// [`ArraySoa`] as their buffer proxy.
impl<T: SoaValue, const N: usize> SoaValue for [T; N] {
    type SoaBuffer = ArraySoa<T, N>;
}
impl<T: SoaValue, const N: usize> SoaBufferProxy for ArraySoa<T, N> {
    type Value = [T; N];
    /// Builds one element proxy per array slot over the same storage.
    ///
    /// Slot `k` is created with offset
    /// `global_offset + k * T::SoaBuffer::num_buffers()`.
    fn from_soa_storage(
        storage: ByteBufferVar,
        meta: Expr<SoaMetadata>,
        global_offset: usize,
    ) -> Self {
        let mut elems = Vec::with_capacity(N);
        for slot in 0..N {
            elems.push(T::SoaBuffer::from_soa_storage(
                storage.clone(),
                meta,
                global_offset + slot * T::SoaBuffer::num_buffers(),
            ));
        }
        Self {
            elems,
            _marker: PhantomData,
        }
    }
    /// Total buffer count: the element type's buffer count times `N`.
    fn num_buffers() -> usize {
        N * T::SoaBuffer::num_buffers()
    }
}
impl<T: SoaValue, const N: usize> IndexRead for ArraySoa<T, N> {
    type Element = [T; N];
    /// Reads index `i` from each of the `N` slot proxies and assembles the
    /// results into one array expression.
    fn read<I: IntoIndex>(&self, i: I) -> Expr<Self::Element> {
        let index = i.to_u64();
        let mut lanes = Vec::with_capacity(N);
        for lane in 0..N {
            lanes.push(self.elems[lane].read(index));
        }
        // Exactly `N` values were pushed, so the conversion to a fixed-size
        // array cannot fail.
        <[T; N]>::from_elems_expr(lanes.try_into().unwrap_or_else(|_| unreachable!()))
    }
}
impl<T: SoaValue, const N: usize> IndexWrite for ArraySoa<T, N> {
    /// Writes each component of `value` to index `i` of the matching slot proxy.
    fn write<I: IntoIndex, V: AsExpr<Value = Self::Element>>(&self, i: I, value: V) {
        let index = i.to_u64();
        let array = value.as_expr();
        (0..N).for_each(|lane| self.elems[lane].write(index, array.read(lane as u64)));
    }
}
impl<T: Value, const N: usize> ArrayExpr<T, N> {
pub fn len(&self) -> Expr<u32> {
(N as u32).expr()
Expand Down
Loading

0 comments on commit 775485b

Please sign in to comment.