diff --git a/luisa_compute/src/lang/math.rs b/luisa_compute/src/lang/math.rs index ef86196..f474ec1 100644 --- a/luisa_compute/src/lang/math.rs +++ b/luisa_compute/src/lang/math.rs @@ -1,19 +1,19 @@ pub use super::swizzle::*; use super::{Aggregate, ExprProxy, Value, VarProxy, __extract, traits::*, Float}; use crate::*; -use serde::{Serialize, Deserialize}; use half::f16; use luisa_compute_ir::{ context::register_type, ir::{Func, MatrixType, NodeRef, Primitive, Type, VectorElementType, VectorType}, TypeOf, }; +use serde::{Deserialize, Serialize}; use std::ops::Mul; macro_rules! def_vec { ($name:ident, $glam_type:ident, $scalar:ty, $align:literal, $($comp:ident), *) => { #[repr(C, align($align))] - #[derive(Copy, Clone, Debug, Default, Serialize, Deserialize)] + #[derive(Copy, Clone, Debug, Default, PartialEq, Serialize, Deserialize)] pub struct $name { $(pub $comp: $scalar), * } @@ -44,7 +44,7 @@ macro_rules! def_vec { macro_rules! def_packed_vec { ($name:ident, $vec_type:ident, $glam_type:ident, $scalar:ty, $($comp:ident), *) => { #[repr(C)] - #[derive(Copy, Clone, Debug, Default, __Value, Serialize, Deserialize)] + #[derive(Copy, Clone, Debug, Default, __Value,PartialEq, Serialize, Deserialize)] pub struct $name { $(pub $comp: $scalar), * } @@ -480,7 +480,7 @@ macro_rules! impl_vec_proxy { } } impl VectorVarTrait for $expr_proxy { } - impl ScalarOrVector for $expr_proxy { + impl ScalarOrVector for $expr_proxy { type Element = Expr<$scalar>; type ElementHost = $scalar; } diff --git a/luisa_compute/src/lang/mod.rs b/luisa_compute/src/lang/mod.rs index 1c83250..6044955 100644 --- a/luisa_compute/src/lang/mod.rs +++ b/luisa_compute/src/lang/mod.rs @@ -38,7 +38,7 @@ use math::Uint3; use std::cell::{Cell, RefCell, UnsafeCell}; use std::ffi::CString; use std::ops::{Bound, Deref, DerefMut, RangeBounds}; - +use std::sync::atomic::AtomicUsize; // use self::math::Uint3; pub mod math; pub mod poly; @@ -50,6 +50,14 @@ pub use math::*; pub use poly::*; pub use printer::*; +pub(crate) static KERNEL_ID: AtomicUsize = AtomicUsize::new(0); +// prevent node being shared across kernels +// TODO: replace NodeRef with SafeNodeRef +#[derive(Clone, Copy, Debug)] +pub(crate) struct SafeNodeRef { + pub(crate) node: NodeRef, + pub(crate) kernel_id: usize, +} pub trait Value: Copy + ir::TypeOf + 'static { type Expr: ExprProxy; type Var: VarProxy; @@ -128,18 +136,51 @@ macro_rules! 
impl_aggregate_for_tuple { } impl_aggregate_for_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); -pub unsafe trait _Mask: ToNode {} +pub unsafe trait Mask: ToNode {} +pub trait IntoIndex { + fn to_u64(&self) -> Expr; +} +impl IntoIndex for i32 { + fn to_u64(&self) -> Expr { + const_(*self as u64) + } +} +impl IntoIndex for i64 { + fn to_u64(&self) -> Expr { + const_(*self as u64) + } +} +impl IntoIndex for u32 { + fn to_u64(&self) -> Expr { + const_(*self as u64) + } +} +impl IntoIndex for u64 { + fn to_u64(&self) -> Expr { + const_(*self) + } +} +impl IntoIndex for PrimExpr { + fn to_u64(&self) -> Expr { + self.ulong() + } +} +impl IntoIndex for PrimExpr { + fn to_u64(&self) -> Expr { + *self + } +} pub trait IndexRead: ToNode { type Element: Value; - fn read>>(&self, i: I) -> Expr; + fn read(&self, i: I) -> Expr; } pub trait IndexWrite: IndexRead { - fn write>, V: Into>>(&self, i: I, value: V); + fn write>>(&self, i: I, value: V); } -pub fn select(mask: impl _Mask, a: A, b: A) -> A { +pub fn select(mask: impl Mask, a: A, b: A) -> A { let a_nodes = a.to_vec_nodes(); let b_nodes = b.to_vec_nodes(); assert_eq!(a_nodes.len(), b_nodes.len()); @@ -178,9 +219,9 @@ impl ToNode for bool { } } -unsafe impl _Mask for bool {} +unsafe impl Mask for bool {} -unsafe impl _Mask for Bool {} +unsafe impl Mask for Bool {} pub trait ExprProxy: Copy + Aggregate + FromNode { type Value: Value; @@ -553,6 +594,7 @@ impl CpuFn { pub(crate) struct Recorder { pub(crate) scopes: Vec, + pub(crate) kernel_id: Option, pub(crate) lock: bool, pub(crate) captured_buffer: IndexMap)>, pub(crate) cpu_custom_ops: IndexMap)>, @@ -576,6 +618,7 @@ impl Recorder { self.block_size = None; self.arena.reset(); self.shared.clear(); + self.kernel_id = None; } pub(crate) fn new() -> Self { Recorder { @@ -590,6 +633,7 @@ impl Recorder { pools: None, arena: Bump::new(), building_kernel: false, + kernel_id: None, } } } @@ -671,6 +715,15 @@ pub fn __module_pools() -> &'static CArc { unsafe { std::mem::transmute(pool) } }) } +// pub fn __load(node: NodeRef) -> Expr { +// __current_scope(|b| { +// let node = b.load(node); +// Expr::::from_node(node) +// }) +// } +// pub fn __store(var:NodeRef, value:NodeRef) { +// let inst = &var.get().instruction; +// } pub fn __extract(node: NodeRef, index: usize) -> NodeRef { let inst = &node.get().instruction; @@ -685,6 +738,14 @@ pub fn __extract(node: NodeRef, index: usize) -> NodeRef { Func::GetElementPtr } } + Instruction::Call(f, args) => match f { + Func::AtomicRef => { + let mut indices = args.to_vec(); + indices.push(i); + return b.call(Func::AtomicRef, &indices, ::type_()); + } + _ => Func::ExtractElement, + }, _ => Func::ExtractElement, }; let node = b.call(op, &[node, i], ::type_()); @@ -759,6 +820,12 @@ macro_rules! 
var { ($t:ty, $init:expr) => { local::<$t>($init.into()) }; + ($e:expr) => { + def($e) + }; +} +pub fn def, T: Value>(init: E) -> Var { + Var::::from_node(__current_scope(|b| b.local(init.node()))) } pub fn local(init: Expr) -> Var { Var::::from_node(__current_scope(|b| b.local(init.node()))) @@ -1294,9 +1361,9 @@ impl Shared { }), } } - pub fn len(&self) -> Expr { + pub fn len(&self) -> Expr { match self.node.type_().as_ref() { - Type::Array(ArrayType { element: _, length }) => const_(*length as u32), + Type::Array(ArrayType { element: _, length }) => const_(*length as u64), _ => unreachable!(), } } @@ -1306,8 +1373,8 @@ impl Shared { _ => unreachable!(), } } - pub fn write>, V: Into>>(&self, i: I, value: V) { - let i = i.into(); + pub fn write>>(&self, i: I, value: V) { + let i = i.to_u64(); let value = value.into(); if need_runtime_check() { @@ -1467,8 +1534,8 @@ impl VLArrayExpr { _ => unreachable!(), } } - pub fn read>>(&self, i: I) -> Expr { - let i = i.into(); + pub fn read(&self, i: I) -> Expr { + let i = i.to_u64(); if need_runtime_check() { lc_assert!(i.cmplt(self.len())); } @@ -1477,9 +1544,9 @@ impl VLArrayExpr { b.call(Func::ExtractElement, &[self.node, i.node()], T::type_()) })) } - pub fn len(&self) -> Expr { + pub fn len(&self) -> Expr { match self.node.type_().as_ref() { - Type::Array(ArrayType { element: _, length }) => const_(*length as u32), + Type::Array(ArrayType { element: _, length }) => const_(*length as u64), _ => unreachable!(), } } @@ -1487,10 +1554,10 @@ impl VLArrayExpr { impl IndexRead for ArrayExpr { type Element = T; - fn read>>(&self, i: I) -> Expr { - let i = i.into(); + fn read(&self, i: I) -> Expr { + let i = i.to_u64(); - lc_assert!(i.cmplt(const_(N as u32))); + lc_assert!(i.cmplt(const_(N as u64))); Expr::::from_node(__current_scope(|b| { b.call(Func::ExtractElement, &[self.node, i.node()], T::type_()) @@ -1500,10 +1567,10 @@ impl IndexRead for ArrayExpr { impl IndexRead for ArrayVar { type Element = T; - fn read>>(&self, i: I) -> Expr { - let i = i.into(); + fn read(&self, i: I) -> Expr { + let i = i.to_u64(); if need_runtime_check() { - lc_assert!(i.cmplt(const_(N as u32))); + lc_assert!(i.cmplt(const_(N as u64))); } Expr::::from_node(__current_scope(|b| { @@ -1514,12 +1581,12 @@ impl IndexRead for ArrayVar { } impl IndexWrite for ArrayVar { - fn write>, V: Into>>(&self, i: I, value: V) { - let i = i.into(); + fn write>>(&self, i: I, value: V) { + let i = i.to_u64(); let value = value.into(); if need_runtime_check() { - lc_assert!(i.cmplt(const_(N as u32))); + lc_assert!(i.cmplt(const_(N as u64))); } __current_scope(|b| { @@ -1666,7 +1733,14 @@ impl CallableParameter for BufferVar { encoder.buffer(self) } } - +impl CallableParameter for ByteBufferVar { + fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { + builder.byte_buffer() + } + fn encode(&self, encoder: &mut CallableArgEncoder) { + encoder.byte_buffer(self) + } +} impl CallableParameter for Tex2dVar { fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { builder.tex2d() @@ -1716,7 +1790,11 @@ where builder.uniform::() } } - +impl KernelParameter for ByteBufferVar { + fn def_param(builder: &mut KernelBuilder) -> Self { + builder.byte_buffer() + } +} impl KernelParameter for BufferVar { fn def_param(builder: &mut KernelBuilder) -> Self { builder.buffer() @@ -1810,6 +1888,17 @@ impl KernelBuilder { self.args.push(node); FromNode::from_node(node) } + pub fn byte_buffer(&mut self) -> ByteBufferVar { + let node = new_node( + __module_pools(), + 
Node::new(CArc::new(Instruction::Buffer), Type::void()), + ); + self.args.push(node); + ByteBufferVar { + node, + handle: None, + } + } pub fn buffer(&mut self) -> BufferVar { let node = new_node( __module_pools(), @@ -2288,7 +2377,7 @@ macro_rules! impl_kernel_build_for_fn { impl_kernel_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); pub fn if_then_else( - cond: impl _Mask, + cond: impl Mask, then: impl Fn() -> R, else_: impl Fn() -> R, ) -> R { diff --git a/luisa_compute/src/lang/printer.rs b/luisa_compute/src/lang/printer.rs index a3ab4ad..6b376b9 100644 --- a/luisa_compute/src/lang/printer.rs +++ b/luisa_compute/src/lang/printer.rs @@ -123,7 +123,7 @@ impl Printer { let item_id = items.len() as u32; if_!( - offset.cmplt(data.len()) & (offset + 1 + args.count as u32).cmple(data.len()), + offset.cmplt(data.len().uint()) & (offset + 1 + args.count as u32).cmple(data.len().uint()), { data.atomic_fetch_add(0, 1); data.write(offset, item_id); diff --git a/luisa_compute/src/lib.rs b/luisa_compute/src/lib.rs index 3702781..c654c6d 100644 --- a/luisa_compute/src/lib.rs +++ b/luisa_compute/src/lib.rs @@ -19,7 +19,7 @@ pub mod prelude { pub use crate::lang::traits::{CommonVarOp, FloatVarTrait, IntVarTrait, VarCmp, VarCmpEq}; pub use crate::lang::{ Aggregate, ExprProxy, FromNode, IndexRead, IndexWrite, KernelBuildFn, KernelParameter, - KernelSignature, Value, VarProxy, _Mask, + KernelSignature, Value, VarProxy, Mask, }; pub use crate::lang::{ __compose, __cpu_dbg, __current_scope, __env_need_backtrace, __extract, __insert, diff --git a/luisa_compute/src/resource.rs b/luisa_compute/src/resource.rs index 6f95a62..2425408 100644 --- a/luisa_compute/src/resource.rs +++ b/luisa_compute/src/resource.rs @@ -13,10 +13,230 @@ use std::process::abort; use std::sync::Arc; pub struct ByteBuffer { - inner: Buffer, + pub(crate) device: Device, + pub(crate) handle: Arc, + pub(crate) len: usize, +} +impl ByteBuffer { + pub fn len(&self) -> usize { + self.len + } + #[inline] + pub fn handle(&self) -> api::Buffer { + self.handle.handle + } + #[inline] + pub fn native_handle(&self) -> *mut c_void { + self.handle.native_handle + } + #[inline] + pub fn copy_from(&self, data: &[u8]) { + self.view(..).copy_from(data); + } + #[inline] + pub fn copy_from_async<'a>(&self, data: &[u8]) -> Command<'_> { + self.view(..).copy_from_async(data) + } + #[inline] + pub fn copy_to(&self, data: &mut [u8]) { + self.view(..).copy_to(data); + } + #[inline] + pub fn copy_to_async<'a>(&self, data: &'a mut [u8]) -> Command<'a> { + self.view(..).copy_to_async(data) + } + #[inline] + pub fn copy_to_vec(&self) -> Vec { + self.view(..).copy_to_vec() + } + #[inline] + pub fn copy_to_buffer(&self, dst: &ByteBuffer) { + self.view(..).copy_to_buffer(dst.view(..)); + } + #[inline] + pub fn copy_to_buffer_async<'a>(&'a self, dst: &'a ByteBuffer) -> Command<'a> { + self.view(..).copy_to_buffer_async(dst.view(..)) + } + #[inline] + pub fn fill_fn u8>(&self, f: F) { + self.view(..).fill_fn(f); + } + #[inline] + pub fn fill(&self, value: u8) { + self.view(..).fill(value); + } + pub fn view>(&self, range: S) -> ByteBufferView<'_> { + let lower = range.start_bound(); + let upper = range.end_bound(); + let lower = match lower { + std::ops::Bound::Included(&x) => x, + std::ops::Bound::Excluded(&x) => x + 1, + std::ops::Bound::Unbounded => 0, + }; + let upper = match upper { + std::ops::Bound::Included(&x) => x + 1, + std::ops::Bound::Excluded(&x) => x, + std::ops::Bound::Unbounded => self.len, + }; + assert!(lower <= upper); + 
assert!(upper <= self.len); + ByteBufferView { + buffer: self, + offset: lower, + len: upper - lower, + } + } + pub fn var(&self) -> ByteBufferVar { + ByteBufferVar::new(&self.view(..)) + } +} +pub struct ByteBufferView<'a> { + pub(crate) buffer: &'a ByteBuffer, + pub(crate) offset: usize, + pub(crate) len: usize, +} +impl<'a> ByteBufferView<'a> { + pub fn handle(&self) -> api::Buffer { + self.buffer.handle() + } + pub fn copy_to_async<'b>(&'a self, data: &'b mut [u8]) -> Command<'b> { + assert_eq!(data.len(), self.len); + let mut rt = ResourceTracker::new(); + rt.add(self.buffer.handle.clone()); + Command { + inner: api::Command::BufferDownload(BufferDownloadCommand { + buffer: self.handle(), + offset: self.offset, + size: data.len(), + data: data.as_mut_ptr() as *mut u8, + }), + marker: std::marker::PhantomData, + resource_tracker: rt, + callback: None, + } + } + pub fn copy_to_vec(&self) -> Vec { + let mut data = Vec::with_capacity(self.len); + unsafe { + let slice = std::slice::from_raw_parts_mut(data.as_mut_ptr(), self.len); + self.copy_to(slice); + data.set_len(self.len); + } + data + } + pub fn copy_to(&self, data: &mut [u8]) { + unsafe { + submit_default_stream_and_sync(&self.buffer.device, [self.copy_to_async(data)]); + } + } + + pub fn copy_from_async<'b>(&'a self, data: &'b [u8]) -> Command<'static> { + assert_eq!(data.len(), self.len); + let mut rt = ResourceTracker::new(); + rt.add(self.buffer.handle.clone()); + Command { + inner: api::Command::BufferUpload(BufferUploadCommand { + buffer: self.handle(), + offset: self.offset, + size: data.len(), + data: data.as_ptr() as *const u8, + }), + marker: std::marker::PhantomData, + resource_tracker: rt, + callback: None, + } + } + pub fn copy_from(&self, data: &[u8]) { + submit_default_stream_and_sync(&self.buffer.device, [self.copy_from_async(data)]); + } + pub fn fill_fn u8>(&self, f: F) { + self.copy_from(&(0..self.len).map(f).collect::>()); + } + pub fn fill(&self, value: u8) { + self.fill_fn(|_| value); + } + pub fn copy_to_buffer_async(&self, dst: ByteBufferView<'a>) -> Command<'static> { + assert_eq!(self.len, dst.len); + let mut rt = ResourceTracker::new(); + rt.add(self.buffer.handle.clone()); + rt.add(dst.buffer.handle.clone()); + Command { + inner: api::Command::BufferCopy(api::BufferCopyCommand { + src: self.handle(), + src_offset: self.offset, + dst: dst.handle(), + dst_offset: dst.offset, + size: self.len, + }), + marker: std::marker::PhantomData, + resource_tracker: rt, + callback: None, + } + } + pub fn copy_to_buffer(&self, dst: ByteBufferView<'a>) { + submit_default_stream_and_sync(&self.buffer.device, [self.copy_to_buffer_async(dst)]); + } } +#[derive(Clone)] pub struct ByteBufferVar { - inner: BufferVar, + #[allow(dead_code)] + pub(crate) handle: Option>, + pub(crate) node: NodeRef, +} +impl ByteBufferVar { + pub fn new(buffer: &ByteBufferView<'_>) -> Self { + let node = RECORDER.with(|r| { + let mut r = r.borrow_mut(); + assert!(r.lock, "BufferVar must be created from within a kernel"); + let binding = Binding::Buffer(BufferBinding { + handle: buffer.handle().0, + size: buffer.len, + offset: buffer.offset as u64, + }); + if let Some((_, node, _, _)) = r.captured_buffer.get(&binding) { + *node + } else { + let node = new_node( + r.pools.as_ref().unwrap(), + Node::new(CArc::new(Instruction::Buffer), Type::void()), + ); + let i = r.captured_buffer.len(); + r.captured_buffer + .insert(binding, (i, node, binding, buffer.buffer.handle.clone())); + node + } + }); + Self { + node, + handle: 
Some(buffer.buffer.handle.clone()), + } + } + pub fn read(&self, index_bytes: impl IntoIndex) -> Expr { + let i = index_bytes.to_u64(); + Expr::::from_node(__current_scope(|b| { + b.call( + Func::ByteBufferRead, + &[self.node, i.node], + ::type_(), + ) + })) + } + pub fn len(&self) -> Expr { + Expr::::from_node(__current_scope(|b| { + b.call(Func::ByteBufferSize, &[self.node], ::type_()) + })) + } + pub fn write(&self, index_bytes: impl IntoIndex, value: impl Into>) { + let i = index_bytes.to_u64(); + let value: Expr = value.into(); + __current_scope(|b| { + b.call( + Func::ByteBufferWrite, + &[self.node, i.node, value.node()], + Type::void(), + ) + }); + } } pub struct Buffer { pub(crate) device: Device, @@ -47,7 +267,7 @@ impl<'a, T: Value> BufferView<'a, T> { pub fn var(&self) -> BufferVar { BufferVar::new(self) } - pub(crate) fn handle(&self) -> api::Buffer { + pub fn handle(&self) -> api::Buffer { self.buffer.handle() } pub fn copy_to_async<'b>(&'a self, data: &'b mut [T]) -> Command<'b> { @@ -130,7 +350,7 @@ impl<'a, T: Value> BufferView<'a, T> { } impl Buffer { #[inline] - pub(crate) fn handle(&self) -> api::Buffer { + pub fn handle(&self) -> api::Buffer { self.handle.handle } #[inline] @@ -182,17 +402,17 @@ impl Buffer { pub fn fill(&self, value: T) { self.view(..).fill(value); } - pub fn view>(&self, range: S) -> BufferView { + pub fn view>(&self, range: S) -> BufferView { let lower = range.start_bound(); let upper = range.end_bound(); let lower = match lower { - std::ops::Bound::Included(&x) => x as usize, - std::ops::Bound::Excluded(&x) => x as usize + 1, + std::ops::Bound::Included(&x) => x, + std::ops::Bound::Excluded(&x) => x + 1, std::ops::Bound::Unbounded => 0, }; let upper = match upper { - std::ops::Bound::Included(&x) => x as usize + 1, - std::ops::Bound::Excluded(&x) => x as usize, + std::ops::Bound::Included(&x) => x + 1, + std::ops::Bound::Excluded(&x) => x, std::ops::Bound::Unbounded => self.len, }; assert!(lower <= upper); @@ -200,7 +420,7 @@ impl Buffer { BufferView { buffer: self, offset: lower, - len: (upper - lower) as usize, + len: upper - lower, } } #[inline] @@ -267,7 +487,7 @@ impl BufferHeap { self.inner.emplace_buffer_async(index, buffer); } pub fn emplace_buffer_view_async<'a>(&self, index: usize, bufferview: &BufferView<'a, T>) { - self.inner.emplace_bufferview_async(index, bufferview); + self.inner.emplace_buffer_view_async(index, bufferview); } pub fn remove_buffer_async(&self, index: usize) { self.inner.remove_buffer_async(index); @@ -278,7 +498,7 @@ impl BufferHeap { } #[inline] pub fn emplace_buffer_view<'a>(&self, index: usize, bufferview: &BufferView<'a, T>) { - self.inner.emplace_bufferview_async(index, bufferview); + self.inner.emplace_buffer_view_async(index, bufferview); } #[inline] pub fn remove_buffer(&self, index: usize) { @@ -343,8 +563,14 @@ impl BindlessArray { pub fn native_handle(&self) -> *mut std::ffi::c_void { self.handle.native_handle } - - pub fn emplace_buffer_async(&self, index: usize, buffer: &Buffer) { + pub fn emplace_byte_buffer_async(&self, index: usize, buffer: &ByteBuffer) { + self.emplace_byte_buffer_view_async(index, &buffer.view(..)) + } + pub fn emplace_byte_buffer_view_async<'a>( + &self, + index: usize, + bufferview: &ByteBufferView<'a>, + ) { self.lock(); self.modifications .borrow_mut() @@ -352,18 +578,21 @@ impl BindlessArray { slot: index, buffer: api::BindlessArrayUpdateBuffer { op: api::BindlessArrayUpdateOperation::Emplace, - handle: buffer.handle.handle, - offset: 0, + handle: bufferview.handle(), + offset: 
bufferview.offset, }, tex2d: api::BindlessArrayUpdateTexture::default(), tex3d: api::BindlessArrayUpdateTexture::default(), }); self.make_pending_slots(); let mut pending = self.pending_slots.borrow_mut(); - pending[index].buffer = Some(buffer.handle.clone()); + pending[index].buffer = Some(bufferview.buffer.handle.clone()); self.unlock(); } - pub fn emplace_bufferview_async<'a, T: Value>( + pub fn emplace_buffer_async(&self, index: usize, buffer: &Buffer) { + self.emplace_buffer_view_async(index, &buffer.view(..)) + } + pub fn emplace_buffer_view_async<'a, T: Value>( &self, index: usize, bufferview: &BufferView<'a, T>, @@ -492,13 +721,23 @@ impl BindlessArray { self.unlock(); } #[inline] + pub fn emplace_byte_buffer(&self, index: usize, buffer: &ByteBuffer) { + self.emplace_byte_buffer_async(index, buffer); + self.update(); + } + #[inline] + pub fn emplace_byte_buffer_view(&self, index: usize, buffer: &ByteBufferView<'_>) { + self.emplace_byte_buffer_view_async(index, buffer); + self.update(); + } + #[inline] pub fn emplace_buffer(&self, index: usize, buffer: &Buffer) { self.emplace_buffer_async(index, buffer); self.update(); } #[inline] pub fn emplace_buffer_view(&self, index: usize, buffer: &BufferView) { - self.emplace_bufferview_async(index, buffer); + self.emplace_buffer_view_async(index, buffer); self.update(); } #[inline] @@ -789,7 +1028,7 @@ pub struct Tex3dView<'a, T: IoTexel> { pub(crate) level: u32, } impl Tex2d { - pub(crate) fn handle(&self) -> api::Texture { + pub fn handle(&self) -> api::Texture { self.handle.handle } pub fn var(&self) -> Tex2dVar { @@ -797,7 +1036,7 @@ impl Tex2d { } } impl Tex3d { - pub(crate) fn handle(&self) -> api::Texture { + pub fn handle(&self) -> api::Texture { self.handle.handle } pub fn var(&self) -> Tex3dVar { @@ -957,7 +1196,7 @@ macro_rules! 
impl_tex_view { }; } impl<'a, T: IoTexel> Tex2dView<'a, T> { - pub(crate) fn handle(&self) -> api::Texture { + pub fn handle(&self) -> api::Texture { self.tex.handle.handle } pub fn texel_count(&self) -> u32 { @@ -977,7 +1216,7 @@ impl<'a, T: IoTexel> Tex2dView<'a, T> { } impl_tex_view!(Tex2dView); impl<'a, T: IoTexel> Tex3dView<'a, T> { - pub(crate) fn handle(&self) -> api::Texture { + pub fn handle(&self) -> api::Texture { self.tex.handle.handle } pub fn texel_count(&self) -> u32 { @@ -1074,8 +1313,8 @@ impl ToNode for BindlessBufferVar { impl IndexRead for BindlessBufferVar { type Element = T; - fn read>>(&self, i: I) -> Expr { - let i = i.into(); + fn read(&self, i: I) -> Expr { + let i = i.to_u64(); if need_runtime_check() { lc_assert!(i.cmplt(self.len())); } @@ -1090,9 +1329,9 @@ impl IndexRead for BindlessBufferVar { } } impl BindlessBufferVar { - pub fn len(&self) -> Expr { - let stride = const_(T::type_().size() as u32); - Expr::::from_node(__current_scope(|b| { + pub fn len(&self) -> Expr { + let stride = const_(T::type_().size() as u64); + Expr::::from_node(__current_scope(|b| { b.call( Func::BindlessBufferSize, &[self.array, self.buffer_index.node(), stride.node()], @@ -1111,6 +1350,38 @@ impl BindlessBufferVar { } } #[derive(Clone)] +pub struct BindlessByteBufferVar { + array: NodeRef, + buffer_index: Expr, +} +impl ToNode for BindlessByteBufferVar { + fn node(&self) -> NodeRef { + self.array + } +} +impl BindlessByteBufferVar { + pub fn read(&self, index_bytes: impl IntoIndex) -> Expr { + let i = index_bytes.to_u64(); + Expr::::from_node(__current_scope(|b| { + b.call( + Func::BindlessByteAdressBufferRead, + &[self.array, self.buffer_index.node(), i.node], + ::type_(), + ) + })) + } + pub fn len(&self) -> Expr { + let s = const_(1u64); + Expr::::from_node(__current_scope(|b| { + b.call( + Func::BindlessBufferSize, + &[self.array, self.buffer_index.node(), s.node()], + ::type_(), + ) + })) + } +} +#[derive(Clone)] pub struct BindlessTex2dVar { array: NodeRef, tex2d_index: Expr, @@ -1300,6 +1571,16 @@ impl BindlessArrayVar { }; v } + pub fn byte_address_buffer( + &self, + buffer_index: impl Into>, + ) -> BindlessByteBufferVar { + let v = BindlessByteBufferVar { + array: self.node, + buffer_index: buffer_index.into(), + }; + v + } pub fn buffer(&self, buffer_index: impl Into>) -> BindlessBufferVar { let v = BindlessBufferVar { array: self.node, @@ -1369,19 +1650,19 @@ impl ToNode for Buffer { } impl IndexRead for Buffer { type Element = T; - fn read>>(&self, i: I) -> Expr { + fn read(&self, i: I) -> Expr { self.var().read(i) } } impl IndexWrite for Buffer { - fn write>, V: Into>>(&self, i: I, v: V) { + fn write>>(&self, i: I, v: V) { self.var().write(i, v) } } impl IndexRead for BufferVar { type Element = T; - fn read>>(&self, i: I) -> Expr { - let i = i.into(); + fn read(&self, i: I) -> Expr { + let i = i.to_u64(); if need_runtime_check() { lc_assert!(i.cmplt(self.len())); } @@ -1395,8 +1676,8 @@ impl IndexRead for BufferVar { } } impl IndexWrite for BufferVar { - fn write>, V: Into>>(&self, i: I, v: V) { - let i = i.into(); + fn write>>(&self, i: I, v: V) { + let i = i.to_u64(); let v = v.into(); if need_runtime_check() { lc_assert!(i.cmplt(self.len())); @@ -1439,9 +1720,9 @@ impl BufferVar { handle: Some(buffer.buffer.handle.clone()), } } - pub fn len(&self) -> Expr { + pub fn len(&self) -> Expr { FromNode::from_node( - __current_scope(|b| b.call(Func::BufferSize, &[self.node], u32::type_())).into(), + __current_scope(|b| b.call(Func::BufferSize, &[self.node], 
u64::type_())).into(), ) } } @@ -1449,12 +1730,8 @@ impl BufferVar { macro_rules! impl_atomic { ($t:ty) => { impl BufferVar<$t> { - pub fn atomic_exchange>, V: Into>>( - &self, - i: I, - v: V, - ) -> Expr<$t> { - let i = i.into(); + pub fn atomic_exchange>>(&self, i: I, v: V) -> Expr<$t> { + let i = i.to_u64(); let v = v.into(); if need_runtime_check() { lc_assert!(i.cmplt(self.len())); @@ -1467,17 +1744,13 @@ macro_rules! impl_atomic { ) })) } - pub fn atomic_compare_exchange< - I: Into>, - V0: Into>, - V1: Into>, - >( + pub fn atomic_compare_exchange>, V1: Into>>( &self, i: I, expected: V0, desired: V1, ) -> Expr<$t> { - let i = i.into(); + let i = i.to_u64(); let expected = expected.into(); let desired = desired.into(); if need_runtime_check() { @@ -1491,12 +1764,12 @@ macro_rules! impl_atomic { ) })) } - pub fn atomic_fetch_add>, V: Into>>( + pub fn atomic_fetch_add>>( &self, i: I, v: V, ) -> Expr<$t> { - let i = i.into(); + let i = i.to_u64(); let v = v.into(); if need_runtime_check() { lc_assert!(i.cmplt(self.len())); @@ -1509,12 +1782,12 @@ macro_rules! impl_atomic { ) })) } - pub fn atomic_fetch_sub>, V: Into>>( + pub fn atomic_fetch_sub>>( &self, i: I, v: V, ) -> Expr<$t> { - let i = i.into(); + let i = i.to_u64(); let v = v.into(); if need_runtime_check() { lc_assert!(i.cmplt(self.len())); @@ -1527,12 +1800,12 @@ macro_rules! impl_atomic { ) })) } - pub fn atomic_fetch_min>, V: Into>>( + pub fn atomic_fetch_min>>( &self, i: I, v: V, ) -> Expr<$t> { - let i = i.into(); + let i = i.to_u64(); let v = v.into(); if need_runtime_check() { lc_assert!(i.cmplt(self.len())); @@ -1545,12 +1818,12 @@ macro_rules! impl_atomic { ) })) } - pub fn atomic_fetch_max>, V: Into>>( + pub fn atomic_fetch_max>>( &self, i: I, v: V, ) -> Expr<$t> { - let i = i.into(); + let i = i.to_u64(); let v = v.into(); if need_runtime_check() { lc_assert!(i.cmplt(self.len())); @@ -1569,12 +1842,12 @@ macro_rules! impl_atomic { macro_rules! impl_atomic_bit { ($t:ty) => { impl BufferVar<$t> { - pub fn atomic_fetch_and>, V: Into>>( + pub fn atomic_fetch_and>>( &self, i: I, v: V, ) -> Expr<$t> { - let i = i.into(); + let i = i.to_u64(); let v = v.into(); if need_runtime_check() { lc_assert!(i.cmplt(self.len())); @@ -1587,12 +1860,8 @@ macro_rules! impl_atomic_bit { ) })) } - pub fn atomic_fetch_or>, V: Into>>( - &self, - i: I, - v: V, - ) -> Expr<$t> { - let i = i.into(); + pub fn atomic_fetch_or>>(&self, i: I, v: V) -> Expr<$t> { + let i = i.to_u64(); let v = v.into(); if need_runtime_check() { lc_assert!(i.cmplt(self.len())); @@ -1605,12 +1874,12 @@ macro_rules! 
impl_atomic_bit { ) })) } - pub fn atomic_fetch_xor>, V: Into>>( + pub fn atomic_fetch_xor>>( &self, i: I, v: V, ) -> Expr<$t> { - let i = i.into(); + let i = i.to_u64(); let v = v.into(); if need_runtime_check() { lc_assert!(i.cmplt(self.len())); diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index 3d225e8..75e33e2 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -142,6 +142,19 @@ impl Device { }; swapchain } + pub fn create_byte_buffer(&self, len: usize) -> ByteBuffer { + let buffer = self.inner.create_buffer(&Type::void(), len); + let buffer = ByteBuffer { + device: self.clone(), + handle: Arc::new(BufferHandle { + device: self.clone(), + handle: api::Buffer(buffer.resource.handle), + native_handle: buffer.resource.native_handle, + }), + len, + }; + buffer + } pub fn create_buffer(&self, count: usize) -> Buffer { assert!( std::mem::size_of::() > 0, @@ -910,6 +923,9 @@ impl CallableArgEncoder { pub fn buffer(&mut self, buffer: &BufferVar) { self.args.push(buffer.node); } + pub fn byte_buffer(&mut self, buffer: &ByteBufferVar) { + self.args.push(buffer.node); + } pub fn tex2d(&mut self, tex2d: &Tex2dVar) { self.args.push(tex2d.node); } @@ -973,6 +989,20 @@ impl KernelArgEncoder { size: buffer.len * std::mem::size_of::(), })); } + pub fn byte_buffer(&mut self, buffer: &ByteBuffer) { + self.args.push(api::Argument::Buffer(api::BufferArgument { + buffer: buffer.handle.handle, + offset: 0, + size: buffer.len, + })); + } + pub fn byte_buffer_view(&mut self, buffer: &ByteBufferView) { + self.args.push(api::Argument::Buffer(api::BufferArgument { + buffer: buffer.handle(), + offset: buffer.offset, + size: buffer.len, + })); + } pub fn tex2d(&mut self, tex: &Tex2dView) { self.args.push(api::Argument::Texture(api::TextureArgument { texture: tex.handle(), @@ -1005,7 +1035,18 @@ impl KernelArg for Buffer { encoder.buffer(self); } } - +impl KernelArg for ByteBuffer { + type Parameter = ByteBufferVar; + fn encode(&self, encoder: &mut KernelArgEncoder) { + encoder.byte_buffer(self); + } +} +impl<'a> KernelArg for ByteBufferView<'a> { + type Parameter = ByteBufferVar; + fn encode(&self, encoder: &mut KernelArgEncoder) { + encoder.byte_buffer_view(self); + } +} impl KernelArg for T { type Parameter = Expr; fn encode(&self, encoder: &mut KernelArgEncoder) { @@ -1234,6 +1275,14 @@ impl<'a, T: Value> AsKernelArg> for BufferView<'a, T> {} impl<'a, T: Value> AsKernelArg> for Buffer {} +impl AsKernelArg for ByteBuffer {} + +impl<'a> AsKernelArg for ByteBufferView<'a> {} + +impl<'a> AsKernelArg> for ByteBufferView<'a> {} + +impl<'a> AsKernelArg> for ByteBuffer {} + impl<'a, T: IoTexel> AsKernelArg> for Tex2dView<'a, T> {} impl<'a, T: IoTexel> AsKernelArg> for Tex3dView<'a, T> {} diff --git a/luisa_compute/tests/autodiff.rs b/luisa_compute/tests/autodiff.rs index 4e980a7..5dcc060 100644 --- a/luisa_compute/tests/autodiff.rs +++ b/luisa_compute/tests/autodiff.rs @@ -15,8 +15,14 @@ fn _signal_handler(signal: libc::c_int) { } static ONCE: std::sync::Once = std::sync::Once::new(); fn get_device() -> Device { - ONCE.call_once(||{ - // init_logger(); + let show_log = match std::env::var("LUISA_TEST_LOG") { + Ok(log) => log == "1", + Err(_) => false, + }; + ONCE.call_once(|| { + if show_log { + init_logger_verbose(); + } unsafe { libc::signal(libc::SIGSEGV, _signal_handler as usize); } @@ -85,32 +91,30 @@ fn autodiff_helper Float>( // inputs[i].view(..).copy_from(&tmp); // } println!("init time: {:?}", tic.elapsed()); - let kernel = device - 
.create_kernel_async::<()>(&|| { - let input_vars = inputs.iter().map(|input| input.var()).collect::>(); - let grad_fd_vars = grad_fd.iter().map(|grad| grad.var()).collect::>(); - let grad_ad_vars = grad_ad.iter().map(|grad| grad.var()).collect::>(); - let tid = dispatch_id().x(); - let inputs = input_vars - .iter() - .map(|input| input.read(tid)) - .collect::>(); - autodiff(|| { - for input in &inputs { - requires_grad(*input); - } - let output = f(&inputs); - backward(output); - for i in 0..n_inputs { - grad_ad_vars[i].write(tid, gradient(inputs[i])); - } - }); - let fd = finite_difference(&inputs, &f); + let kernel = device.create_kernel_async::<()>(&|| { + let input_vars = inputs.iter().map(|input| input.var()).collect::>(); + let grad_fd_vars = grad_fd.iter().map(|grad| grad.var()).collect::>(); + let grad_ad_vars = grad_ad.iter().map(|grad| grad.var()).collect::>(); + let tid = dispatch_id().x(); + let inputs = input_vars + .iter() + .map(|input| input.read(tid)) + .collect::>(); + autodiff(|| { + for input in &inputs { + requires_grad(*input); + } + let output = f(&inputs); + backward(output); for i in 0..n_inputs { - grad_fd_vars[i].write(tid, fd[i]); + grad_ad_vars[i].write(tid, gradient(inputs[i])); } - }) - ; + }); + let fd = finite_difference(&inputs, &f); + for i in 0..n_inputs { + grad_fd_vars[i].write(tid, fd[i]); + } + }); let tic = std::time::Instant::now(); kernel.dispatch([repeats as u32, 1, 1]); println!("kernel time: {:?}", tic.elapsed()); @@ -614,25 +618,23 @@ fn autodiff_select() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device - .create_kernel::<()>(&|| { - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x(); - let x = buf_x.read(tid); - let y = buf_y.read(tid); - autodiff(|| { - requires_grad(x); - requires_grad(y); - let z = select(x.cmpgt(y), x * 4.0, y * 0.5); - backward(z); - buf_dx.write(tid, gradient(x)); - buf_dy.write(tid, gradient(y)); - }); - }) - ; + let kernel = device.create_kernel::<()>(&|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x(); + let x = buf_x.read(tid); + let y = buf_y.read(tid); + autodiff(|| { + requires_grad(x); + requires_grad(y); + let z = select(x.cmpgt(y), x * 4.0, y * 0.5); + backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + }); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); @@ -660,26 +662,24 @@ fn autodiff_detach() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device - .create_kernel::<()>(&|| { - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x(); - let x = buf_x.read(tid); - let y = buf_y.read(tid); - autodiff(|| { - requires_grad(x); - requires_grad(y); - let k = detach(x * y); - let z = (x + y) * k; - backward(z); - buf_dx.write(tid, gradient(x)); - buf_dy.write(tid, gradient(y)); - }); - }) - ; + let kernel = device.create_kernel::<()>(&|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x(); + let x = buf_x.read(tid); + let y = buf_y.read(tid); + autodiff(|| { + requires_grad(x); + requires_grad(y); + let k = detach(x * y); + let z = (x + y) * k; + backward(z); + buf_dx.write(tid, 
gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + }); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); @@ -712,27 +712,25 @@ fn autodiff_select_nan() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen::() + 10.0); - let kernel = device - .create_kernel::<()>(&|| { - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x(); - let x = buf_x.read(tid); - let y = buf_y.read(tid); - autodiff(|| { - requires_grad(x); - requires_grad(y); - let cond = x.cmpgt(y); - let a = (x - y).sqrt(); - let z = select(cond, a, y * 0.5); - backward(z); - buf_dx.write(tid, gradient(x)); - buf_dy.write(tid, gradient(y)); - }); - }) - ; + let kernel = device.create_kernel::<()>(&|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x(); + let x = buf_x.read(tid); + let y = buf_y.read(tid); + autodiff(|| { + requires_grad(x); + requires_grad(y); + let cond = x.cmpgt(y); + let a = (x - y).sqrt(); + let z = select(cond, a, y * 0.5); + backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + }); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); @@ -755,32 +753,30 @@ fn autodiff_if_nan() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen::() + 10.0); - let kernel = device - .create_kernel::<()>(&|| { - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x(); - let x = buf_x.read(tid); - let y = buf_y.read(tid); - autodiff(|| { - requires_grad(x); - requires_grad(y); - let cond = x.cmpgt(y); - let z = if_!(cond, { - let a = (x - y).sqrt(); - a - }, else { - y * 0.5 - }); - // cpu_dbg!(f32, z); - backward(z); - buf_dx.write(tid, gradient(x)); - buf_dy.write(tid, gradient(y)); + let kernel = device.create_kernel::<()>(&|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x(); + let x = buf_x.read(tid); + let y = buf_y.read(tid); + autodiff(|| { + requires_grad(x); + requires_grad(y); + let cond = x.cmpgt(y); + let z = if_!(cond, { + let a = (x - y).sqrt(); + a + }, else { + y * 0.5 }); - }) - ; + // cpu_dbg!(f32, z); + backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + }); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); @@ -808,31 +804,29 @@ fn autodiff_if_phi() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device - .create_kernel::<()>(&|| { - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x(); - let x = buf_x.read(tid); - let y = buf_y.read(tid); - if_!(true, { - autodiff(|| { - requires_grad(x); - requires_grad(y); - let z = if_!(x.cmpgt(y), { - x * 4.0 - }, else { - y * 0.5 - }); - backward(z); - buf_dx.write(tid, gradient(x)); - buf_dy.write(tid, gradient(y)); + let kernel = device.create_kernel::<()>(&|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x(); + let x = buf_x.read(tid); + let y = buf_y.read(tid); + if_!(true, { + autodiff(|| { + requires_grad(x); + requires_grad(y); + let z 
= if_!(x.cmpgt(y), { + x * 4.0 + }, else { + y * 0.5 }); + backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); }); - }) - ; + }); + }); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); @@ -860,33 +854,31 @@ fn autodiff_if_phi2() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device - .create_kernel::<()>(&|| { - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x(); - let x = buf_x.read(tid); - let y = buf_y.read(tid); - autodiff(|| { - requires_grad(x); - requires_grad(y); - let z = if_!(x.cmpgt(y), { - if_!(x.cmpgt(3.0), { - x * 4.0 - }, else { - x * 2.0 - }) + let kernel = device.create_kernel::<()>(&|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x(); + let x = buf_x.read(tid); + let y = buf_y.read(tid); + autodiff(|| { + requires_grad(x); + requires_grad(y); + let z = if_!(x.cmpgt(y), { + if_!(x.cmpgt(3.0), { + x * 4.0 }, else { - y * 0.5 - }); - backward(z); - buf_dx.write(tid, gradient(x)); - buf_dy.write(tid, gradient(y)); + x * 2.0 + }) + }, else { + y * 0.5 }); - }) - ; + backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + }); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); @@ -918,33 +910,31 @@ fn autodiff_if_phi3() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device - .create_kernel::<()>(&|| { - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x(); - let x = buf_x.read(tid); - let y = buf_y.read(tid); - autodiff(|| { - requires_grad(x); - requires_grad(y); - let c = x.cmpgt(3.0).int(); - let z = if_!(x.cmpgt(y), { - switch::>(c) - .case(0, || x * 2.0) - .default(|| x * 4.0) - .finish() * 2.0 - }, else { - y * 0.5 - }); - backward(z); - buf_dx.write(tid, gradient(x)); - buf_dy.write(tid, gradient(y)); + let kernel = device.create_kernel::<()>(&|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x(); + let x = buf_x.read(tid); + let y = buf_y.read(tid); + autodiff(|| { + requires_grad(x); + requires_grad(y); + let c = x.cmpgt(3.0).int(); + let z = if_!(x.cmpgt(y), { + switch::>(c) + .case(0, || x * 2.0) + .default(|| x * 4.0) + .finish() * 2.0 + }, else { + y * 0.5 }); - }) - ; + backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + }); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); @@ -978,31 +968,29 @@ fn autodiff_switch() { t.view(..).fill_fn(|_| rng.gen_range(0..3)); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device - .create_kernel::<()>(&|| { - let buf_t = t.var(); - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x(); - let x = buf_x.read(tid); - let y = buf_y.read(tid); - let t = buf_t.read(tid); - autodiff(|| { - requires_grad(x); - requires_grad(y); - let z = switch::>(t) - .case(0, || x * 4.0) - .case(1, || x * 2.0) - .case(2, || y * 0.5) - .finish(); - backward(z); - buf_dx.write(tid, gradient(x)); - buf_dy.write(tid, gradient(y)); - }); - }) - ; + let 
kernel = device.create_kernel::<()>(&|| { + let buf_t = t.var(); + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x(); + let x = buf_x.read(tid); + let y = buf_y.read(tid); + let t = buf_t.read(tid); + autodiff(|| { + requires_grad(x); + requires_grad(y); + let z = switch::>(t) + .case(0, || x * 4.0) + .case(1, || x * 2.0) + .case(2, || y * 0.5) + .finish(); + backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + }); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); diff --git a/luisa_compute/tests/misc.rs b/luisa_compute/tests/misc.rs index 138d2b8..8b03733 100644 --- a/luisa_compute/tests/misc.rs +++ b/luisa_compute/tests/misc.rs @@ -13,7 +13,14 @@ fn _signal_handler(signal: libc::c_int) { } static ONCE: std::sync::Once = std::sync::Once::new(); fn get_device() -> Device { + let show_log = match std::env::var("LUISA_TEST_LOG") { + Ok(log) => log == "1", + Err(_) => false, + }; ONCE.call_once(|| unsafe { + if show_log { + init_logger_verbose(); + } libc::signal(libc::SIGSEGV, _signal_handler as usize); }); let curr_exe = current_exe().unwrap(); @@ -651,3 +658,147 @@ fn uniform() { let expected = (x.len() as f32 - 1.0) * x.len() as f32 * 0.5 * 6.0; assert!((actual - expected).abs() < 1e-4); } +#[derive(Clone, Copy, Debug, __Value)] +#[repr(C)] +struct Big { + a: [f32; 32], +} +#[test] +fn byte_buffer() { + let device = get_device(); + let buf = device.create_byte_buffer(1024); + let mut big = Big { a: [1.0; 32] }; + for i in 0..32 { + big.a[i] = i as f32; + } + let mut cnt = 0usize; + macro_rules! push { + ($t:ty, $v:expr) => {{ + let old = cnt; + let s = std::mem::size_of::<$t>(); + let view = buf.view(cnt..cnt + s); + let bytes = unsafe { std::slice::from_raw_parts(&$v as *const $t as *const u8, s) }; + view.copy_from(bytes); + cnt += s; + old + }}; + } + let i0 = push!(Float3, Float3::new(0.0, 0.0, 0.0)); + let i1 = push!(Big, big); + let i2 = push!(i32, 0i32); + let i3 = push!(f32, 1f32); + device + .create_kernel::<()>(&|| { + let buf = buf.var(); + let i0 = i0 as u64; + let i1 = i1 as u64; + let i2 = i2 as u64; + let i3 = i3 as u64; + let v0 = def(buf.read::(i0)); + let v1 = def(buf.read::(i1)); + let v2 = def(buf.read::(i2)); + let v3 = def(buf.read::(i3)); + *v0.get_mut() = make_float3(1.0, 2.0, 3.0); + for_range(0u32..32u32, |i| { + v1.a().write(i, i.float() * 2.0); + }); + *v2.get_mut() = 1i32.into(); + *v3.get_mut() = 2.0.into(); + buf.write::(i0, v0.load()); + buf.write::(i1, v1.load()); + buf.write::(i2, v2.load()); + buf.write::(i3, v3.load()); + }) + .dispatch([1, 1, 1]); + let data = buf.copy_to_vec(); + macro_rules! 
pop { + ($t:ty, $offset:expr) => {{ + let s = std::mem::size_of::<$t>(); + let bytes = &data[$offset..$offset + s]; + let v = unsafe { std::mem::transmute_copy::<[u8; {std::mem::size_of::<$t>()}], $t>(bytes.try_into().unwrap()) }; + v + }}; + } + let v0 = pop!(Float3, i0); + let v1 = pop!(Big, i1); + let v2 = pop!(i32, i2); + let v3 = pop!(f32, i3); + assert_eq!(v0, Float3::new(1.0,2.0,3.0)); + assert_eq!(v2, 1); + assert_eq!(v3, 2.0); + for i in 0..32 { + assert!(v1.a[i] == i as f32 * 2.0); + } +} + +#[test] +fn bindless_byte_buffer() { + let device = get_device(); + let buf = device.create_byte_buffer(1024); + let out = device.create_byte_buffer(1024); + let mut big = Big { a: [1.0; 32] }; + for i in 0..32 { + big.a[i] = i as f32; + } + let heap = device.create_bindless_array(64); + heap.emplace_byte_buffer(0, &buf); + let mut cnt = 0usize; + macro_rules! push { + ($t:ty, $v:expr) => {{ + let old = cnt; + let s = std::mem::size_of::<$t>(); + let view = buf.view(cnt..cnt + s); + let bytes = unsafe { std::slice::from_raw_parts(&$v as *const $t as *const u8, s) }; + view.copy_from(bytes); + cnt += s; + old + }}; + } + let i0 = push!(Float3, Float3::new(0.0, 0.0, 0.0)); + let i1 = push!(Big, big); + let i2 = push!(i32, 0i32); + let i3 = push!(f32, 1f32); + device + .create_kernel::<(ByteBuffer,)>(&|out:ByteBufferVar| { + let heap = heap.var(); + let buf = heap.byte_address_buffer(0); + let i0 = i0 as u64; + let i1 = i1 as u64; + let i2 = i2 as u64; + let i3 = i3 as u64; + let v0 = def(buf.read::(i0)); + let v1 = def(buf.read::(i1)); + let v2 = def(buf.read::(i2)); + let v3 = def(buf.read::(i3)); + *v0.get_mut() = make_float3(1.0, 2.0, 3.0); + for_range(0u32..32u32, |i| { + v1.a().write(i, i.float() * 2.0); + }); + *v2.get_mut() = 1i32.into(); + *v3.get_mut() = 2.0.into(); + out.write::(i0, v0.load()); + out.write::(i1, v1.load()); + out.write::(i2, v2.load()); + out.write::(i3, v3.load()); + }) + .dispatch([1, 1, 1], &out); + let data = out.copy_to_vec(); + macro_rules! pop { + ($t:ty, $offset:expr) => {{ + let s = std::mem::size_of::<$t>(); + let bytes = &data[$offset..$offset + s]; + let v = unsafe { std::mem::transmute_copy::<[u8; {std::mem::size_of::<$t>()}], $t>(bytes.try_into().unwrap()) }; + v + }}; + } + let v0 = pop!(Float3, i0); + let v1 = pop!(Big, i1); + let v2 = pop!(i32, i2); + let v3 = pop!(f32, i3); + assert_eq!(v0, Float3::new(1.0,2.0,3.0)); + assert_eq!(v2, 1); + assert_eq!(v3, 2.0); + for i in 0..32 { + assert!(v1.a[i] == i as f32 * 2.0); + } +} diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index 75d0324..515d62b 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit 75d0324176a321926f52041b5db21c386481702e +Subproject commit 515d62b67911fffd0289e0aa3715df1eebde3f89
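
Below is a minimal host-plus-kernel sketch of the new untyped `ByteBuffer` API introduced in this change, modeled directly on the `byte_buffer` test added to `misc.rs`. The `read::<T>` / `write::<T>` turbofish, the import lines, and the `Device` parameter are assumptions reconstructed from the surrounding test code, not taken verbatim from the hunks.

```rust
use luisa_compute as luisa;
use luisa::prelude::*;
use luisa::*; // Float3, make_float3, def, ByteBuffer, ... as in the test files

// Sketch: pack a Float3 followed by an f32 into one untyped buffer,
// then rewrite both fields from a kernel using byte offsets.
fn byte_buffer_demo(device: &Device) {
    let buf = device.create_byte_buffer(64);
    buf.fill(0);
    let f3_offset = 0u64; // byte offset of the Float3
    let f_offset = std::mem::size_of::<Float3>() as u64; // byte offset of the f32

    device
        .create_kernel::<()>(&|| {
            let buf = buf.var();
            // read::<T>/write::<T> take *byte* offsets (any IntoIndex value).
            let v = def(buf.read::<Float3>(f3_offset));
            *v.get_mut() = make_float3(1.0, 2.0, 3.0);
            buf.write::<Float3>(f3_offset, v.load());
            buf.write::<f32>(f_offset, buf.read::<f32>(f_offset) + 1.0);
        })
        .dispatch([1, 1, 1]);
}
```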
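
The index parameters on `IndexRead`/`IndexWrite`, `Shared`, `VLArrayExpr`, and the buffer atomics move from `Into<Expr<u32>>` to the new `IntoIndex` trait, and the various `len()` methods now return a 64-bit expression. A small sketch of what a call site accepts under that change; the exact generic bounds are reconstructed from context rather than visible verbatim in the hunks.

```rust
use luisa_compute as luisa;
use luisa::prelude::*;
use luisa::*;

// Sketch: host integers and device index expressions both satisfy IntoIndex;
// they are all widened to a u64 index node internally.
fn index_demo(device: &Device) {
    let buf = device.create_buffer::<f32>(128);
    buf.fill(1.0);
    device
        .create_kernel::<()>(&|| {
            let buf = buf.var();
            let tid = dispatch_id().x(); // Expr<u32>
            let a = buf.read(tid);       // device-side u32 index
            let b = buf.read(5u64);      // host u64 index
            let c = buf.read(7i32);      // host i32 index
            buf.write(tid, a + b + c);
            let _n = buf.len();          // now a u64 expression
        })
        .dispatch([128, 1, 1]);
}
```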
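
A sketch of the bindless path added here, mirroring the `bindless_byte_buffer` test: a `ByteBuffer` is placed in a slot with `emplace_byte_buffer` and read back through `byte_address_buffer`. The turbofish on `read`/`write` and the imports are again assumptions restored from context.

```rust
use luisa_compute as luisa;
use luisa::prelude::*;
use luisa::*;

// Sketch: read a value out of a bindless byte-address buffer and write the
// result into a ByteBuffer passed as a kernel argument.
fn bindless_byte_buffer_demo(device: &Device) {
    let data = device.create_byte_buffer(256);
    data.fill(0);
    let out = device.create_byte_buffer(256);
    let heap = device.create_bindless_array(4);
    heap.emplace_byte_buffer(0, &data); // emplace + update, as in the new API

    device
        .create_kernel::<(ByteBuffer,)>(&|out: ByteBufferVar| {
            let heap = heap.var();
            let buf = heap.byte_address_buffer(0); // slot 0, byte-addressed
            let v = buf.read::<f32>(0u64);
            out.write::<f32>(0u64, v + 1.0);
        })
        .dispatch([1, 1, 1], &out);
}
```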