From d2be70699b046cea93c3bc5a5d2bc0977702d742 Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Sun, 17 Sep 2023 09:30:44 -0400 Subject: [PATCH] fix return_v() --- luisa_compute/examples/callable.rs | 14 ++++- luisa_compute/examples/callable_advanced.rs | 14 ++++- luisa_compute/src/lang/mod.rs | 47 ++++++++++++-- luisa_compute/tests/misc.rs | 69 +++++++++++++++++++-- luisa_compute_sys/LuisaCompute | 2 +- 5 files changed, 132 insertions(+), 14 deletions(-) diff --git a/luisa_compute/examples/callable.rs b/luisa_compute/examples/callable.rs index 52a88b9..1029317 100644 --- a/luisa_compute/examples/callable.rs +++ b/luisa_compute/examples/callable.rs @@ -8,7 +8,19 @@ fn main() { use luisa::*; init_logger(); let ctx = Context::new(current_exe().unwrap()); - let device = ctx.create_device("cpu"); + let args: Vec = std::env::args().collect(); + assert!( + args.len() <= 2, + "Usage: {} . : cpu, cuda, dx, metal, remote", + args[0] + ); + + let ctx = Context::new(current_exe().unwrap()); + let device = ctx.create_device(if args.len() == 2 { + args[1].as_str() + } else { + "cpu" + }); let add = device.create_callable::, Expr)->Expr>(&|a, b| a + b); let x = device.create_buffer::(1024); let y = device.create_buffer::(1024); diff --git a/luisa_compute/examples/callable_advanced.rs b/luisa_compute/examples/callable_advanced.rs index 057ac49..74a0e61 100644 --- a/luisa_compute/examples/callable_advanced.rs +++ b/luisa_compute/examples/callable_advanced.rs @@ -5,7 +5,19 @@ fn main() { use luisa::*; init_logger(); let ctx = Context::new(current_exe().unwrap()); - let device = ctx.create_device("cpu"); + let args: Vec = std::env::args().collect(); + assert!( + args.len() <= 2, + "Usage: {} . : cpu, cuda, dx, metal, remote", + args[0] + ); + + let ctx = Context::new(current_exe().unwrap()); + let device = ctx.create_device(if args.len() == 2 { + args[1].as_str() + } else { + "cpu" + }); let add = device.create_dyn_callable:: DynExpr>(Box::new( |a: DynExpr, b: DynExpr| -> DynExpr { if let Some(a) = a.downcast::() { diff --git a/luisa_compute/src/lang/mod.rs b/luisa_compute/src/lang/mod.rs index 16fa16f..b56a078 100644 --- a/luisa_compute/src/lang/mod.rs +++ b/luisa_compute/src/lang/mod.rs @@ -605,6 +605,7 @@ pub(crate) struct Recorder { pub(crate) building_kernel: bool, pub(crate) pools: Option>, pub(crate) arena: Bump, + pub(crate) callable_ret_type: Option>, } impl Recorder { @@ -619,6 +620,7 @@ impl Recorder { self.arena.reset(); self.shared.clear(); self.kernel_id = None; + self.callable_ret_type = None; } pub(crate) fn new() -> Self { Recorder { @@ -634,6 +636,7 @@ impl Recorder { arena: Bump::new(), building_kernel: false, kernel_id: None, + callable_ret_type: None, } } } @@ -2009,6 +2012,14 @@ impl KernelBuilder { let mut r = r.borrow_mut(); assert!(r.lock); r.lock = false; + if let Some(t) = &r.callable_ret_type { + assert!( + luisa_compute_ir::context::is_type_equal(t, &ret_type), + "Return type mismatch" + ); + } else { + r.callable_ret_type = Some(ret_type.clone()); + } assert_eq!(r.scopes.len(), 1); let scope = r.scopes.pop().unwrap(); let entry = scope.finish(); @@ -2732,12 +2743,38 @@ pub fn continue_() { }); } -// pub fn return_v(v: T) { -// __current_scope(|b| { -// b.return_(Some(v.node())); -// }); -// } +pub fn return_v(v: T) { + RECORDER.with(|r| { + let mut r = r.borrow_mut(); + if r.callable_ret_type.is_none() { + r.callable_ret_type = Some(v.node().type_().clone()); + } else { + assert!( + luisa_compute_ir::context::is_type_equal( + r.callable_ret_type.as_ref().unwrap(), + v.node().type_() + ), + "return type mismatch" + ); + } + }); + __current_scope(|b| { + b.return_(v.node()); + }); +} + pub fn return_() { + RECORDER.with(|r| { + let mut r = r.borrow_mut(); + if r.callable_ret_type.is_none() { + r.callable_ret_type = Some(Type::void()); + } else { + assert!(luisa_compute_ir::context::is_type_equal( + r.callable_ret_type.as_ref().unwrap(), + &Type::void() + )); + } + }); __current_scope(|b| { b.return_(INVALID_REF); }); diff --git a/luisa_compute/tests/misc.rs b/luisa_compute/tests/misc.rs index 5d57835..752e240 100644 --- a/luisa_compute/tests/misc.rs +++ b/luisa_compute/tests/misc.rs @@ -83,6 +83,55 @@ fn event() { assert_eq!(v[0], (1 + 3) * (4 + 5)); } #[test] +#[should_panic] +fn callable_return_mismatch() { + let device = get_device(); + let _abs = device.create_callable::) -> Expr>(&|x| { + if_!(x.cmpgt(0.0), { + return_v(const_(true)); + }); + -x + }); +} +#[test] +#[should_panic] +fn callable_return_void_mismatch() { + let device = get_device(); + let _abs = device.create_callable::)>(&|x| { + if_!(x.cmpgt(0.0), { + return_v(const_(true)); + }); + x.store(-*x); + }); +} +#[test] +fn callable_early_return() { + let device = get_device(); + let abs = device.create_callable::) -> Expr>(&|x| { + if_!(x.cmpgt(0.0), { + return_v(x); + }); + -x + }); + let x = device.create_buffer::(1024); + let mut rng = StdRng::seed_from_u64(0); + x.fill_fn(|_| rng.gen()); + let y = device.create_buffer::(1024); + device + .create_kernel::(&|| { + let i = dispatch_id().x(); + let x = x.var().read(i); + let y = y.var(); + y.write(i, abs.call(x)); + }) + .dispatch([x.len() as u32, 1, 1]); + let x = x.copy_to_vec(); + let y = y.copy_to_vec(); + for i in 0..x.len() { + assert_eq!(y[i], x[i].abs()); + } +} +#[test] fn callable() { let device = get_device(); let write = device.create_callable::, Expr, Var)>( @@ -91,7 +140,7 @@ fn callable() { v.store(v.load() + 1); }, ); - let add = device.create_callable::, Expr)->Expr>(&|a, b| a + b); + let add = device.create_callable::, Expr) -> Expr>(&|a, b| a + b); let x = device.create_buffer::(1024); let y = device.create_buffer::(1024); let z = device.create_buffer::(1024); @@ -715,7 +764,11 @@ fn byte_buffer() { ($t:ty, $offset:expr) => {{ let s = std::mem::size_of::<$t>(); let bytes = &data[$offset..$offset + s]; - let v = unsafe { std::mem::transmute_copy::<[u8; {std::mem::size_of::<$t>()}], $t>(bytes.try_into().unwrap()) }; + let v = unsafe { + std::mem::transmute_copy::<[u8; { std::mem::size_of::<$t>() }], $t>( + bytes.try_into().unwrap(), + ) + }; v }}; } @@ -723,7 +776,7 @@ fn byte_buffer() { let v1 = pop!(Big, i1); let v2 = pop!(i32, i2); let v3 = pop!(f32, i3); - assert_eq!(v0, Float3::new(1.0,2.0,3.0)); + assert_eq!(v0, Float3::new(1.0, 2.0, 3.0)); assert_eq!(v2, 1); assert_eq!(v3, 2.0); for i in 0..32 { @@ -759,7 +812,7 @@ fn bindless_byte_buffer() { let i2 = push!(i32, 0i32); let i3 = push!(f32, 1f32); device - .create_kernel::(&|out:ByteBufferVar| { + .create_kernel::(&|out: ByteBufferVar| { let heap = heap.var(); let buf = heap.byte_address_buffer(0); let i0 = i0 as u64; @@ -787,7 +840,11 @@ fn bindless_byte_buffer() { ($t:ty, $offset:expr) => {{ let s = std::mem::size_of::<$t>(); let bytes = &data[$offset..$offset + s]; - let v = unsafe { std::mem::transmute_copy::<[u8; {std::mem::size_of::<$t>()}], $t>(bytes.try_into().unwrap()) }; + let v = unsafe { + std::mem::transmute_copy::<[u8; { std::mem::size_of::<$t>() }], $t>( + bytes.try_into().unwrap(), + ) + }; v }}; } @@ -795,7 +852,7 @@ fn bindless_byte_buffer() { let v1 = pop!(Big, i1); let v2 = pop!(i32, i2); let v3 = pop!(f32, i3); - assert_eq!(v0, Float3::new(1.0,2.0,3.0)); + assert_eq!(v0, Float3::new(1.0, 2.0, 3.0)); assert_eq!(v2, 1); assert_eq!(v3, 2.0); for i in 0..32 { diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index 739c69a..e318a82 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit 739c69a3ba44d94ecd116c91e6b991125a942506 +Subproject commit e318a8209c0978daf22d686d2fdf2681eb98ccc4