diff --git a/luisa_compute/src/lang/soa.rs b/luisa_compute/src/lang/soa.rs index 6fb6a77..6ce80fa 100644 --- a/luisa_compute/src/lang/soa.rs +++ b/luisa_compute/src/lang/soa.rs @@ -1,14 +1,139 @@ +use crate::internal_prelude::*; use crate::prelude::*; -use luisa_compute_ir::ir::Type; +use crate::runtime::submit_default_stream_and_sync; +use parking_lot::Mutex; +use std::ops::RangeBounds; use std::sync::Arc; +use crate::runtime::Kernel; + +use super::index::IntoIndex; use super::types::SoaValue; /** A buffer with SOA layout. */ pub struct SoaBuffer { - storage: Arc, - metadata: Buffer, - _marker: std::marker::PhantomData, + pub(crate) device: Device, + pub(crate) storage: Arc, + pub(crate) metadata_buf: Buffer, + pub(crate) metadata: SoaMetadata, + pub(crate) copy_kernel: Mutex>>, + pub(crate) _marker: std::marker::PhantomData, +} +pub(crate) struct SoaBufferCopyKernel { + copy_to: Kernel, Buffer, u64)>, + copy_from: Kernel, Buffer, u64)>, +} +impl SoaBufferCopyKernel { + #[tracked] + fn new(device: &Device) -> Self { + let copy_to = + device.create_kernel::, Buffer, u64)>(&|soa, buf, offset| { + let i = dispatch_id().x.as_u64() + offset; + let v = soa.read(i); + buf.write(i, v); + }); + let copy_from = + device.create_kernel::, Buffer, u64)>(&|soa, buf, offset| { + let i = dispatch_id().x.as_u64() + offset; + let v = buf.read(i); + soa.write(i, v); + }); + Self { copy_to, copy_from } + } +} +impl SoaBuffer { + pub fn var(&self) -> SoaBufferVar { + self.view(..).var() + } + pub fn len(&self) -> usize { + self.metadata.count as usize + } + pub fn len_expr(&self) -> Expr { + self.metadata_buf.read(0).count + } + pub fn view>(&self, range: S) -> SoaBufferView { + let lower = range.start_bound(); + let upper = range.end_bound(); + let lower = match lower { + std::ops::Bound::Included(&x) => x, + std::ops::Bound::Excluded(&x) => x + 1, + std::ops::Bound::Unbounded => 0, + }; + let upper = match upper { + std::ops::Bound::Included(&x) => x + 1, + std::ops::Bound::Excluded(&x) => x, + std::ops::Bound::Unbounded => self.len(), + }; + assert!(lower <= upper); + assert!(upper <= self.metadata.count as usize); + let metadata = SoaMetadata { + count: self.metadata.count, + view_start: lower as u64, + view_count: (upper - lower) as u64, + }; + SoaBufferView { + metadata_buf: self.device.create_buffer_from_slice(&[metadata]), + metadata, + buffer: self, + } + } + pub fn copy_from_buffer_async(&self, buffer: &Buffer) -> Command<'static, 'static> { + self.view(..).copy_from_buffer_async(buffer) + } + pub fn copy_from_buffer(&self, buffer: &Buffer) { + self.view(..).copy_from_buffer(buffer) + } + pub fn copy_to_buffer_async(&self, buffer: &Buffer) -> Command<'static, 'static> { + self.view(..).copy_to_buffer_async(buffer) + } + pub fn copy_to_buffer(&self, buffer: &Buffer) { + self.view(..).copy_to_buffer(buffer) + } +} +impl<'a, T: SoaValue> SoaBufferView<'a, T> { + fn init_copy_kernel(&self) { + let mut copy_kernel = self.buffer.copy_kernel.lock(); + if copy_kernel.is_none() { + *copy_kernel = Some(SoaBufferCopyKernel::new(&self.buffer.device)); + } + } + pub fn var(&self) -> SoaBufferVar { + SoaBufferVar { + proxy: T::SoaBuffer::from_soa_storage( + self.buffer.storage.var(), + self.metadata_buf.read(0), + 0, + ), + } + } + pub fn copy_from_buffer_async(&self, buffer: &Buffer) -> Command<'static, 'static> { + self.init_copy_kernel(); + let copy_kernel = self.buffer.copy_kernel.lock(); + let copy_kernel = copy_kernel.as_ref().unwrap(); + copy_kernel.copy_from.dispatch_async( + [self.metadata.view_count as u32, 1, 1], + self, + buffer, + &self.metadata.view_start, + ) + } + pub fn copy_from_buffer(&self, buffer: &Buffer) { + submit_default_stream_and_sync(&self.buffer.device, [self.copy_from_buffer_async(buffer)]); + } + pub fn copy_to_buffer_async(&self, buffer: &Buffer) -> Command<'static, 'static> { + self.init_copy_kernel(); + let copy_kernel = self.buffer.copy_kernel.lock(); + let copy_kernel = copy_kernel.as_ref().unwrap(); + copy_kernel.copy_to.dispatch_async( + [self.metadata.view_count as u32, 1, 1], + self, + buffer, + &self.metadata.view_start, + ) + } + pub fn copy_to_buffer(&self, buffer: &Buffer) { + submit_default_stream_and_sync(&self.buffer.device, [self.copy_to_buffer_async(buffer)]); + } } #[derive(Clone, Copy, Value)] #[repr(C)] @@ -19,28 +144,22 @@ pub struct SoaMetadata { pub view_start: u64, pub view_count: u64, } -pub(crate) struct SoaStorage { - data: Arc, -} pub struct SoaBufferView<'a, T: SoaValue> { - metadata: Buffer, - buffer: &'a SoaBuffer, + pub(crate) metadata_buf: Buffer, + pub(crate) metadata: SoaMetadata, + pub(crate) buffer: &'a SoaBuffer, } pub struct SoaBufferVar { - proxy: T::SoaBuffer, + pub(crate) proxy: T::SoaBuffer, } - -fn compute_number_of_32bits_buffers(ty: &Type) -> usize { - (ty.size() + 3) / 4 +impl IndexRead for SoaBufferVar { + type Element = T; + fn read(&self, i: I) -> Expr { + self.proxy.read(i) + } +} +impl IndexWrite for SoaBufferVar { + fn write>(&self, i: I, value: V) { + self.proxy.write(i, value) + } } - -// impl IndexRead for SoaBuffer -// where -// T: Value, -// { -// type Element = T; -// fn read(&self, i: I) -> Expr { -// let i = i.to_u64(); -// todo!() -// } -// } diff --git a/luisa_compute/src/lang/types/array.rs b/luisa_compute/src/lang/types/array.rs index 12897a9..2a92fac 100644 --- a/luisa_compute/src/lang/types/array.rs +++ b/luisa_compute/src/lang/types/array.rs @@ -20,6 +20,56 @@ impl ArrayNewExpr for [T; N] { impl_simple_expr_proxy!([T: Value, const N: usize] ArrayExpr[T, N] for [T; N]); impl_simple_var_proxy!([T: Value, const N: usize] ArrayVar[T, N] for [T; N]); impl_simple_atomic_ref_proxy!([T: Value, const N: usize] ArrayAtomicRef[T, N] for [T; N]); +#[derive(Clone)] +pub struct ArraySoa { + pub(crate) elems: Vec, + _marker: PhantomData<[T; N]>, +} +impl SoaValue for [T; N] { + type SoaBuffer = ArraySoa; +} +impl SoaBufferProxy for ArraySoa { + type Value = [T; N]; + fn from_soa_storage( + storage: ByteBufferVar, + meta: Expr, + global_offset: usize, + ) -> Self { + let elems = (0..N) + .map(|i| { + T::SoaBuffer::from_soa_storage( + storage.clone(), + meta, + global_offset + i * T::SoaBuffer::num_buffers(), + ) + }) + .collect::>(); + Self { + elems, + _marker: PhantomData, + } + } + fn num_buffers() -> usize { + T::SoaBuffer::num_buffers() * N + } +} +impl IndexRead for ArraySoa { + type Element = [T; N]; + fn read(&self, i: I) -> Expr { + let i = i.to_u64(); + let elems = (0..N).map(|j| self.elems[j].read(i)).collect::>(); + <[T; N]>::from_elems_expr(elems.try_into().unwrap_or_else(|_| unreachable!())) + } +} +impl IndexWrite for ArraySoa { + fn write>(&self, i: I, value: V) { + let i = i.to_u64(); + let value = value.as_expr(); + for j in 0..N { + self.elems[j].write(i, value.read(j as u64)); + } + } +} impl ArrayExpr { pub fn len(&self) -> Expr { (N as u32).expr() diff --git a/luisa_compute/src/lang/types/core.rs b/luisa_compute/src/lang/types/core.rs index a7bd886..3f3f05b 100644 --- a/luisa_compute/src/lang/types/core.rs +++ b/luisa_compute/src/lang/types/core.rs @@ -1,3 +1,9 @@ +use std::any::TypeId; + +use serde_json::value::Index; + +use crate::lang::soa::SoaBuffer; + use super::*; mod private { @@ -24,6 +30,7 @@ pub trait Primitive: private::Sealed + Copy + TypeOf + 'static { /** * This is the heart of SOA implementation. */ +#[derive(Clone)] pub struct PrimitiveSoaProxy { /// this soa view starts from (self.global_offset * self.count * 4) of the global bytebuffer /// Each primitive must be stored in a 4-aligned region, due to dx12 does not support access <4 aligned values @@ -36,23 +43,21 @@ pub struct PrimitiveSoaProxy { pub(crate) data: ByteBufferVar, _marker: std::marker::PhantomData, } -impl IndexRead for PrimitiveSoaProxy { - type Element = T; + +impl IndexRead for PrimitiveSoaProxy { + type Element = bool; #[tracked] fn read(&self, i: I) -> Expr { - let i = i.to_u64(); - if need_runtime_check() { - lc_assert!(i.lt(self.view_count)); - } - unsafe { - self.data.read_as::( + let v = unsafe { + self.data.read_as::( self.global_offset * self.count * 4 - + (self.view_start + i) * std::mem::size_of::() as u64, + + (self.view_start + i.to_u64()) * std::mem::size_of::() as u64, ) - } + }; + v.ne(0) } } -impl IndexWrite for PrimitiveSoaProxy { +impl IndexWrite for PrimitiveSoaProxy { #[tracked] fn write>( &self, @@ -61,19 +66,181 @@ impl IndexWrite for PrimitiveSoaProxy { ) { let i = i.to_u64(); let v = value.as_expr(); - if need_runtime_check() { - lc_assert!(i.lt(self.view_count)); - } unsafe { - self.data.write_as::( + self.data.write_as::( self.global_offset * self.count * 4 - + (self.view_start + i) * std::mem::size_of::() as u64, - v, + + (self.view_start + i) * std::mem::size_of::() as u64, + select(v, 1u32.expr(), 0u32.expr()), ); } } } -impl SoaBufferProxy for PrimitiveSoaProxy { +macro_rules! impl_prim_soa_16 { + ($T:ty) => { + impl IndexRead for PrimitiveSoaProxy<$T> { + type Element = $T; + #[tracked] + fn read(&self, i: I) -> Expr { + let i = i.to_u64(); + if need_runtime_check() { + lc_assert!(i.lt(self.view_count)); + } + + unsafe { + let v = self.data.read_as::( + self.global_offset * self.count * 4 + + (self.view_start + i) * std::mem::size_of::() as u64, + ); + let v = (v & 0xffff).as_u16(); + v.bitcast::<$T>() + } + } + } + impl IndexWrite for PrimitiveSoaProxy<$T> { + #[tracked] + fn write>( + &self, + i: I, + value: V, + ) { + let i = i.to_u64(); + let v = value.as_expr(); + if need_runtime_check() { + lc_assert!(i.lt(self.view_count)); + } + unsafe { + let v = v.bitcast::(); + let v = v.as_u32(); + self.data.write_as::( + self.global_offset * self.count * 4 + + (self.view_start + i) * std::mem::size_of::() as u64, + v, + ); + } + } + } + }; +} +macro_rules! impl_prim_soa_8 { + ($T:ty) => { + impl IndexRead for PrimitiveSoaProxy<$T> { + type Element = $T; + #[tracked] + fn read(&self, i: I) -> Expr { + let i = i.to_u64(); + if need_runtime_check() { + lc_assert!(i.lt(self.view_count)); + } + + unsafe { + let v = self.data.read_as::( + self.global_offset * self.count * 4 + + (self.view_start + i) * std::mem::size_of::() as u64, + ); + let v = (v & 0xff).as_u8(); + v.bitcast::<$T>() + } + } + } + impl IndexWrite for PrimitiveSoaProxy<$T> { + #[tracked] + fn write>( + &self, + i: I, + value: V, + ) { + let i = i.to_u64(); + let v = value.as_expr(); + if need_runtime_check() { + lc_assert!(i.lt(self.view_count)); + } + unsafe { + let v = v.bitcast::(); + let v = v.as_u32(); + self.data.write_as::( + self.global_offset * self.count * 4 + + (self.view_start + i) * std::mem::size_of::() as u64, + v, + ); + } + } + } + }; +} +macro_rules! impl_prim_soa { + ($T:ty) => { + impl IndexRead for PrimitiveSoaProxy<$T> { + type Element = $T; + #[tracked] + fn read(&self, i: I) -> Expr { + let i = i.to_u64(); + if need_runtime_check() { + lc_assert!(i.lt(self.view_count)); + } + + assert!(std::mem::align_of::<$T>() >= 4); + unsafe { + self.data.read_as::<$T>( + self.global_offset * self.count * 4 + + (self.view_start + i) * std::mem::size_of::<$T>() as u64, + ) + } + } + } + impl IndexWrite for PrimitiveSoaProxy<$T> { + #[tracked] + fn write>( + &self, + i: I, + value: V, + ) { + let i = i.to_u64(); + let v = value.as_expr(); + if need_runtime_check() { + lc_assert!(i.lt(self.view_count)); + } + unsafe { + self.data.write_as::<$T>( + self.global_offset * self.count * 4 + + (self.view_start + i) * std::mem::size_of::<$T>() as u64, + v, + ); + } + } + } + }; +} +impl_prim_soa_8!(u8); +impl_prim_soa_8!(i8); +impl_prim_soa_16!(u16); +impl_prim_soa_16!(i16); +impl_prim_soa_16!(f16); +impl_prim_soa!(f32); +impl_prim_soa!(f64); +impl_prim_soa!(i32); +impl_prim_soa!(i64); +impl_prim_soa!(u32); +impl_prim_soa!(u64); +#[allow(dead_code)] +#[allow(unreachable_code)] +fn check_soa_impl() { + let _bool: SoaBuffer = unimplemented!(); + let _f16: SoaBuffer = unimplemented!(); + let _f32: SoaBuffer = unimplemented!(); + let _f64: SoaBuffer = unimplemented!(); + let _i8: SoaBuffer = unimplemented!(); + let _i16: SoaBuffer = unimplemented!(); + let _i32: SoaBuffer = unimplemented!(); + let _i64: SoaBuffer = unimplemented!(); + let _u8: SoaBuffer = unimplemented!(); + let _u16: SoaBuffer = unimplemented!(); + let _u32: SoaBuffer = unimplemented!(); + let _u64: SoaBuffer = unimplemented!(); +} +impl SoaBufferProxy for PrimitiveSoaProxy +where + Self: IndexRead + IndexWrite, +{ type Value = T; fn from_soa_storage( storage: ByteBufferVar, @@ -103,7 +270,10 @@ impl Value for T { Expr::::from_node(node) } } -impl SoaValue for T { +impl SoaValue for T +where + PrimitiveSoaProxy: IndexWrite + IndexRead, +{ type SoaBuffer = PrimitiveSoaProxy; } diff --git a/luisa_compute/src/lib.rs b/luisa_compute/src/lib.rs index 9cda3bb..ade93e6 100644 --- a/luisa_compute/src/lib.rs +++ b/luisa_compute/src/lib.rs @@ -35,7 +35,7 @@ pub mod prelude { }; pub use crate::lang::types::vector::swizzle::*; pub use crate::lang::types::vector::VectorExprProxy; - pub use crate::lang::types::{AsExpr, Expr, Value, Var}; + pub use crate::lang::types::{AsExpr, Expr, Value, Var, SoaValue}; pub use crate::lang::Aggregate; pub use crate::resource::{IoTexel, StorageTexel, *}; pub use crate::runtime::api::StreamTag; @@ -60,7 +60,7 @@ mod internal_prelude { }; pub(crate) use crate::lang::ops::Linear; pub(crate) use crate::lang::types::vector::alias::*; - pub(crate) use crate::lang::types::vector::*; + pub(crate) use crate::lang::types::{SoaBufferProxy, vector::*}; #[allow(unused_imports)] pub(crate) use crate::lang::{ check_index_lt_usize, ir, CallFuncTrait, Recorder, __compose, __extract, __insert, @@ -86,8 +86,7 @@ pub use {luisa_compute_backend as backend, luisa_compute_sys as sys}; use lazy_static::lazy_static; use luisa_compute_backend::Backend; -use parking_lot::lock_api::RawMutex as RawMutexTrait; -use parking_lot::{Mutex, RawMutex}; +use parking_lot::Mutex; use runtime::{Device, DeviceHandle, StreamHandle}; use std::collections::HashMap; use std::sync::Weak; diff --git a/luisa_compute/src/resource.rs b/luisa_compute/src/resource.rs index 654d20f..8bf9b66 100644 --- a/luisa_compute/src/resource.rs +++ b/luisa_compute/src/resource.rs @@ -323,7 +323,9 @@ impl Drop for BufferHandle { #[derive(Clone, Copy)] pub struct BufferView<'a, T: Value> { pub(crate) buffer: &'a Buffer, + /// offset in #elements pub(crate) offset: usize, + /// length in #elements pub(crate) len: usize, } impl<'a, T: Value> BufferView<'a, T> { diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index 8bdf0ac..8ce842f 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -16,7 +16,7 @@ use raw_window_handle::HasRawWindowHandle; use winit::window::Window; use crate::internal_prelude::*; -use crate::lang::soa::SoaBuffer; +use crate::lang::soa::{SoaBuffer, SoaBufferVar, SoaBufferView, SoaMetadata}; use crate::lang::types::SoaValue; use ir::{ CallableModule, CallableModuleRef, Capture, CpuCustomOp, KernelModule, Module, ModuleFlags, @@ -197,8 +197,25 @@ impl Device { buffer } pub fn create_soa_buffer(&self, count: usize) -> SoaBuffer { + assert!(count <= u32::MAX as usize, "count must be less than u32::MAX. This limitation may be removed in the future."); // let inner = self.create_byte_buffer(len) - todo!() + let num_buffers = T::SoaBuffer::num_buffers(); + let storage = Arc::new(self.create_byte_buffer(count * num_buffers * std::mem::size_of::())); + let metadata = SoaMetadata { + count: count as u64, + view_start: 0, + view_count: count as u64, + }; + let metadata_buf = self.create_buffer_from_slice(&[metadata]); + let buffer = SoaBuffer { + storage, + metadata_buf, + metadata, + _marker: PhantomData, + copy_kernel: Mutex::new(None), + device: self.clone(), + }; + buffer } pub fn create_buffer(&self, count: usize) -> Buffer { let name = self.name(); @@ -1014,10 +1031,18 @@ impl KernelArgEncoder { size: buffer.len * std::mem::size_of::(), })); } + pub fn soa_buffer(&mut self, buffer: &SoaBuffer) { + self.buffer(&buffer.storage); + self.buffer(&buffer.metadata_buf); + } + pub fn soa_buffer_view(&mut self, view: &SoaBufferView) { + self.buffer(view.buffer.storage.as_ref()); + self.buffer::(&view.buffer.metadata_buf); + } pub fn buffer_view(&mut self, buffer: &BufferView) { self.args.push(api::Argument::Buffer(api::BufferArgument { buffer: buffer.handle(), - offset: buffer.offset, + offset: buffer.offset* std::mem::size_of::(), size: buffer.len * std::mem::size_of::(), })); } @@ -1067,6 +1092,18 @@ impl KernelArg for Buffer { encoder.buffer(self); } } +impl KernelArg for SoaBuffer { + type Parameter = SoaBufferVar; + fn encode(&self, encoder: &mut KernelArgEncoder) { + encoder.soa_buffer(self); + } +} +impl<'a, T:SoaValue> KernelArg for SoaBufferView<'a, T> { + type Parameter = SoaBufferVar; + fn encode(&self, encoder: &mut KernelArgEncoder) { + encoder.soa_buffer_view(self); + } +} // impl KernelArg for ByteBuffer { // type Parameter = ByteBufferVar; // fn encode(&self, encoder: &mut KernelArgEncoder) { @@ -1396,7 +1433,12 @@ impl AsKernelArg for Buffer { impl<'a, T: Value> AsKernelArg for BufferView<'a, T> { type Output = Buffer; } - +impl AsKernelArg for SoaBuffer { + type Output = SoaBuffer; +} +impl<'a, T: SoaValue> AsKernelArg for SoaBufferView<'a, T> { + type Output = SoaBuffer; +} impl<'a, T: IoTexel> AsKernelArg for Tex2dView<'a, T> { type Output = Tex2d; } diff --git a/luisa_compute/src/runtime/kernel.rs b/luisa_compute/src/runtime/kernel.rs index 283180b..72b1d38 100644 --- a/luisa_compute/src/runtime/kernel.rs +++ b/luisa_compute/src/runtime/kernel.rs @@ -1,3 +1,5 @@ +use crate::lang::soa::SoaMetadata; + use super::*; impl CallableParameter for Expr { @@ -129,7 +131,12 @@ impl KernelParameter for BufferVar { builder.buffer() } } - +impl KernelParameter for SoaBufferVar { + type Arg = SoaBuffer; + fn def_param(builder: &mut KernelBuilder) -> Self { + builder.soa_buffer() + } +} // impl KernelParameter for ByteBufferVar { // type Arg = ByteBuffer; // fn def_param(builder: &mut KernelBuilder) -> Self { @@ -244,6 +251,13 @@ impl KernelBuilder { handle: None, } } + pub fn soa_buffer(&mut self) -> SoaBufferVar { + let storage = self.buffer::(); + let metadata = self.buffer::(); + SoaBufferVar{ + proxy: T::SoaBuffer::from_soa_storage(storage, metadata.read(0), 0) + } + } pub fn tex2d(&mut self) -> Tex2dVar { let node = new_node( __module_pools(), @@ -338,7 +352,6 @@ impl KernelBuilder { }) } - /// Don't use this directly /// See [`Callable`] for how to create a callable #[doc(hidden)] diff --git a/luisa_compute/tests/misc.rs b/luisa_compute/tests/misc.rs index ceab4f2..ecbadd7 100644 --- a/luisa_compute/tests/misc.rs +++ b/luisa_compute/tests/misc.rs @@ -896,6 +896,49 @@ fn buffer_u8() { kernel.dispatch([1024, 1, 1]); } #[test] +fn buffer_view_copy() { + let device = get_device(); + let n = 1024; + let buf = device.create_buffer::(n); + let first_half = buf.view(0..n / 2); + let second_half = buf.view(n / 2..); + first_half.fill_fn(|i| i as f32); + second_half.fill_fn(|i| -(i as f32)); + let data = buf.copy_to_vec(); + for i in 0..n { + if i < n / 2 { + assert_eq!(data[i], i as f32); + } else { + assert_eq!(data[i], -((i - n / 2) as f32)); + } + } +} +#[test] +fn buffer_view() { + let device = get_device(); + let n = 1024; + let buf = device.create_buffer::(n); + let first_half = buf.view(0..n / 2); + let second_half = buf.view(n / 2..); + let kernel = Kernel::, Buffer)>::new( + &device, + &track!(|a, b| { + let tid = dispatch_id().x; + a.write(tid, tid.as_f32()); + b.write(tid, -tid.as_f32()); + }), + ); + kernel.dispatch([n as u32 / 2, 1, 1], &first_half, &second_half); + let data = buf.copy_to_vec(); + for i in 0..n { + if i < n / 2 { + assert_eq!(data[i], i as f32); + } else { + assert_eq!(data[i], -((i - n / 2) as f32)); + } + } +} +#[test] fn byte_buffer() { let device = get_device(); let buf = device.create_byte_buffer(1024); @@ -1130,7 +1173,7 @@ fn is_nan() { // track!(*a + b) // } -#[derive(Clone, Copy, Debug, Value, PartialEq)] +#[derive(Clone, Copy, Debug, Value, Soa, PartialEq)] #[repr(C)] #[value_new(pub)] pub struct Foo { @@ -1138,7 +1181,23 @@ pub struct Foo { v: Float2, a: [i32; 4], } - +#[test] +fn soa() { + let device = get_device(); + let mut rng = thread_rng(); + let foos = device.create_buffer_from_fn(1024, |_| Foo { + i: rng.gen(), + v: Float2::new(rng.gen(), rng.gen()), + a: [rng.gen(), rng.gen(), rng.gen(), rng.gen()], + }); + let foos_soa = device.create_soa_buffer::(1024); + foos_soa.copy_from_buffer(&foos); + let also_foos = device.create_buffer(1024); + foos_soa.copy_to_buffer(&also_foos); + let foos_data = foos.view(..).copy_to_vec(); + let also_foos_data = also_foos.view(..).copy_to_vec(); + assert_eq!(foos_data, also_foos_data); +} #[test] fn atomic() { let device = get_device(); diff --git a/luisa_compute_derive_impl/src/lib.rs b/luisa_compute_derive_impl/src/lib.rs index 95b5734..3887d6a 100644 --- a/luisa_compute_derive_impl/src/lib.rs +++ b/luisa_compute_derive_impl/src/lib.rs @@ -155,56 +155,54 @@ impl Compiler { let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect(); let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect(); let soa_proxy_name = syn::Ident::new(&format!("{}Soa", name), name.span()); - quote_spanned!(span=>{ - #[repr(C)] - #[derive(Copy, Clone)] - pub struct #soa_proxy_name #generics #where_clause{ - #(#field_vis #field_names: <#field_types as #lang_path::SoaValue>::SoaBuffer),* + quote_spanned!(span=> + #[derive(Clone)] + #vis struct #soa_proxy_name #generics #where_clause{ + #(#field_vis #field_names: <#field_types as #lang_path::types::SoaValue>::SoaBuffer),* } - impl #impl_generics #lang_path::SoaValue for #name #ty_generics #where_clause{ + impl #impl_generics #lang_path::types::SoaValue for #name #ty_generics #where_clause{ type SoaBuffer = #soa_proxy_name #ty_generics; } - impl #impl_generics #lang_path::SoaBufferProxy for #soa_proxy_name #ty_generics #where_clause{ + impl #impl_generics #lang_path::types::SoaBufferProxy for #soa_proxy_name #ty_generics #where_clause{ type Value = #name #ty_generics; #[allow(unused_assignments)] fn from_soa_storage( - storage: ByteBufferVar, - meta: Expr, - global_offset: usize, + ___storage: ::luisa_compute::resource::ByteBufferVar, + ___meta: Expr<#lang_path::soa::SoaMetadata>, + ___global_offset: usize, ) -> Self { - use #lang_path::SoaBufferProxy; - let mut i = 0; + use #lang_path::types::SoaBufferProxy; + let mut ___i = 0usize; #( - let $field_names = T::SoaBuffer::from_soa_storage( - storage.clone(), - meta.clone(), - global_offset + i, + let #field_names = <#field_types as #lang_path::types::SoaValue>::SoaBuffer::from_soa_storage( + ___storage.clone(), + ___meta.clone(), + ___global_offset + ___i, ); - i += <#field_types::SoaBuffer as SoaBufferProxy>::num_buffers(); + ___i += <<#field_types as #lang_path::types::SoaValue>::SoaBuffer as #lang_path::types::SoaBufferProxy>::num_buffers(); )* Self{ #(#field_names),* } } fn num_buffers() -> usize { - [#( <#field_types as #lang_path::SoaValue>::SoaBuffer::num_buffers()),*].iter().sum() + [#( <#field_types as #lang_path::types::SoaValue>::SoaBuffer::num_buffers()),*].iter().sum() } } - impl #impl_generics #lang_path::IndexRead for #soa_proxy_name #ty_generics #where_clause{ + impl #impl_generics #lang_path::index::IndexRead for #soa_proxy_name #ty_generics #where_clause{ type Element = #name #ty_generics; - fn read(&self, i: I) -> Expr { - let i = i.to_u64(); + fn read(&self, ___i: I) -> #lang_path::types::Expr { + let ___i = ___i.to_u64(); + use #lang_path::FromNode; #( - let #field_names = self.#field_names.read(i); + let #field_names = self.#field_names.read(___i); )* - Self{ - #(#field_names),* - } + Expr::::from_node(#lang_path::__compose::(&[ #( #lang_path::ToNode::node(&#field_names) ),* ])) } } - impl #impl_generics #lang_path::IndexWrite for #soa_proxy_name #ty_generics #where_clause{ - fn write>(&self, i: I, value: V) { + impl #impl_generics #lang_path::index::IndexWrite for #soa_proxy_name #ty_generics #where_clause{ + fn write>(&self, i: I, value: V) { let i = i.to_u64(); let v = value.as_expr(); #( @@ -212,7 +210,7 @@ impl Compiler { )* } } - }) + ) } pub fn derive_value(&self, struct_: &ItemStruct) -> TokenStream { let ordering = self.value_attributes(&struct_.attrs); diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index 6887785..46d8999 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit 688778557bf4fe6cb87f0d3cd3caf0dd3259d880 +Subproject commit 46d89997f6495156083d6cf2d138c50442c708d1