diff --git a/README.md b/README.md index 0ee0589..4449388 100644 --- a/README.md +++ b/README.md @@ -298,14 +298,14 @@ let result = my_add.call(args); ``` ### Callable -Users can define device-only functions using Callables. Callables have similar type signature to kernels: `Callable`. +Users can define device-only functions using Callables. Callables have similar type signature to kernels: `CallableRet>`. The difference is that Callables are not dispatchable and can only be called from other Callables or Kernels. Callables can be created using `Device::create_callable`. To invoke a Callable, use `Callable::call(args...)`. Callables accepts arguments such as resources (`BufferVar`, .etc), expressions and references (pass a `Var` to the callable). For example: ```rust -let add = device.create_callable::<(Expr, Expr), Expr>(&|a, b| { +let add = device.create_callable::, Expr)-> Expr>(&|a, b| { a + b }); let z = add.call(x, y); -let pass_by_ref = device.create_callable::<(Var,), ()>(&|a| { +let pass_by_ref = device.create_callable::)>(&|a| { *a.get_mut() += 1.0; }); let a = var!(f32, 1.0); @@ -314,9 +314,9 @@ cpu_dbg!(*a); // prints 2.0 ``` ***Note***: You cannot record a callable when recording another kernel or callables. This is because a callable can capture outer variables such as buffers. However, capturing local variables define in another callable is undefined behavior. To avoid this, we disallow recording a callable when recording another callable or kernel. ```rust -let add = device.create_callable::<(Expr, Expr), Expr>(&|a, b| { +let add = device.create_callable::, Expr)-> Expr>(&|a, b| { // runtime error! - let another_add = device.create_callable::<(Expr, Expr), Expr>(&|a, b| { + let another_add = device.create_callable::, Expr)-> Expr>(&|a, b| { a + b }); a + b @@ -327,7 +327,7 @@ let add = device.create_callable::<(Expr, Expr), Expr>(&|a, b| { 1. Use static callables. A static callable does not capture any resources and thus can be safely recorded inside any callable/kernel. To create a static callable, use `create_static_callable(fn)`. For example, ```rust lazy_static! { - static ref ADD:Callable<(Expr, Expr), Expr> = create_static_callable::<(Expr, Expr), Expr>(|a, b| { + static ref ADD:Callable, Expr)->Expr> = create_static_callable::, Expr)->Expr>(|a, b| { a + b }); } @@ -337,9 +337,9 @@ ADD.call(x, y); 2. Use `DynCallable`. These are callables that defer recording until being called. As a result, it requires you to pass a `'static` closure, avoiding the capture issue. To create a `DynCallable`, use `Device::create_dyn_callable(Box::new(fn))`. The syntax is the same as `create_callable`. Furthermore, `DynCallable` supports `DynExpr` and `DynVar`, which provides some capablitiy of implementing template/overloading inside EDSL. ```rust -let add = device.create_callable::<(Expr, Expr), Expr>(&|a, b| { +let add = device.create_callable::, Expr)->Expr>(&|a, b| { // no error! - let another_add = device.create_dyn_callable::<(Expr, Expr), Expr>(Box::new(|a, b| { + let another_add = device.create_dyn_callable::, Expr)->Expr>(Box::new(|a, b| { a + b })); a + b @@ -349,10 +349,9 @@ let add = device.create_callable::<(Expr, Expr), Expr>(&|a, b| { ### Kernel A kernel can be written in a closure or a function. The closure/function should have a `Fn(/*args*/)->()` signature, where the args are taking the `Var` type of resources, such as `BufferVar`, `Tex2D`, etc. -Note: `Device::create_kernel` takes a tuple of types as its generic parameter. If the kernel takes a single argument, it is required to use `create_kernel::<(Type,)>` instead of `create_kernel::`. ```rust -let kernel = device.create_kernel::<(Arg0, Arg1, ...)>(&|/*args*/| { +let kernel = device.create_kernel::(&|/*args*/| { /*body*/ }); kernel.dispatch([/*dispatch size*/], &arg0, &arg1, ...); @@ -360,7 +359,7 @@ kernel.dispatch([/*dispatch size*/], &arg0, &arg1, ...); There are two ways to pass arguments to a kernel: by arguments or by capture. ```rust let captured:Buffer = device.create_buffer(...); -let kernel = device.create_kernel::<(BufferVar, )>(arg| { +let kernel = device.create_kernel::>(arg| { let v = arg.read(..); let u = captured.var().read(..); })); @@ -372,7 +371,7 @@ pub struct BufferPair { a:Buffer, b:Buffer } -let kernel = device.create_kernel::<(BufferPair, )>(&|| { +let kernel = device.create_kernel::(&|| { // ... }); let a = device.create_buffer(...); diff --git a/luisa_compute/examples/atomic.rs b/luisa_compute/examples/atomic.rs index 68614ff..f686e00 100644 --- a/luisa_compute/examples/atomic.rs +++ b/luisa_compute/examples/atomic.rs @@ -11,7 +11,7 @@ fn main() { let sum = device.create_buffer::(1); x.view(..).fill_fn(|i| i as f32); sum.view(..).fill(0.0); - let shader = device.create_kernel::<()>(&|| { + let shader = device.create_kernel::(&|| { let buf_x = x.var(); let buf_sum = sum.var(); let tid = luisa::dispatch_id().x(); diff --git a/luisa_compute/examples/autodiff.rs b/luisa_compute/examples/autodiff.rs index 56f3629..ca8188a 100644 --- a/luisa_compute/examples/autodiff.rs +++ b/luisa_compute/examples/autodiff.rs @@ -23,7 +23,7 @@ fn main() { let dy = device.create_buffer::(1024); x.fill_fn(|i| i as f32); y.fill_fn(|i| 1.0 + i as f32); - let shader = device.create_kernel::<(Buffer, Buffer, Buffer, Buffer)>( + let shader = device.create_kernel::, Buffer, Buffer, Buffer)>( &|buf_x: BufferVar, buf_y: BufferVar, buf_dx: BufferVar, diff --git a/luisa_compute/examples/backtrace.rs b/luisa_compute/examples/backtrace.rs index 4153c50..373f054 100644 --- a/luisa_compute/examples/backtrace.rs +++ b/luisa_compute/examples/backtrace.rs @@ -13,7 +13,7 @@ fn main() { let z = device.create_buffer::(1024); x.view(..).fill_fn(|i| i as f32); y.view(..).fill_fn(|i| 1000.0 * i as f32); - let kernel = device.create_kernel::<(Buffer,)>(&|buf_z| { + let kernel = device.create_kernel::)>(&|buf_z| { // z is pass by arg let buf_x = x.var(); // x and y are captured let buf_y = y.var(); diff --git a/luisa_compute/examples/bindgroup.rs b/luisa_compute/examples/bindgroup.rs index c84b8fb..c330331 100644 --- a/luisa_compute/examples/bindgroup.rs +++ b/luisa_compute/examples/bindgroup.rs @@ -22,6 +22,6 @@ fn main() { y, exclude: 42.0, }; - let shader = device.create_kernel::<(MyArgStruct,)>(&|_args| {}); + let shader = device.create_kernel::)>(&|_args| {}); shader.dispatch([1024, 1, 1], &my_args); } diff --git a/luisa_compute/examples/bindless.rs b/luisa_compute/examples/bindless.rs index 93ae4bb..bd4d719 100644 --- a/luisa_compute/examples/bindless.rs +++ b/luisa_compute/examples/bindless.rs @@ -60,7 +60,7 @@ fn main() { bindless.emplace_buffer_async(1, &y); bindless.emplace_tex2d_async(0, &img, Sampler::default()); bindless.update(); - let kernel = device.create_kernel::<(BufferView,)>(&|buf_z| { + let kernel = device.create_kernel::)>(&|buf_z| { let bindless = bindless.var(); let tid = dispatch_id().x(); let buf_x = bindless.buffer::(Uint::from(0)); diff --git a/luisa_compute/examples/callable.rs b/luisa_compute/examples/callable.rs index 45dee32..52a88b9 100644 --- a/luisa_compute/examples/callable.rs +++ b/luisa_compute/examples/callable.rs @@ -9,13 +9,13 @@ fn main() { init_logger(); let ctx = Context::new(current_exe().unwrap()); let device = ctx.create_device("cpu"); - let add = device.create_callable::<(Expr, Expr), Expr>(&|a, b| a + b); + let add = device.create_callable::, Expr)->Expr>(&|a, b| a + b); let x = device.create_buffer::(1024); let y = device.create_buffer::(1024); let z = device.create_buffer::(1024); x.view(..).fill_fn(|i| i as f32); y.view(..).fill_fn(|i| 1000.0 * i as f32); - let kernel = device.create_kernel::<(Buffer,)>(&|buf_z| { + let kernel = device.create_kernel::)>(&|buf_z| { let buf_x = x.var(); let buf_y = y.var(); let tid = dispatch_id().x(); diff --git a/luisa_compute/examples/callable_advanced.rs b/luisa_compute/examples/callable_advanced.rs index d790d81..057ac49 100644 --- a/luisa_compute/examples/callable_advanced.rs +++ b/luisa_compute/examples/callable_advanced.rs @@ -6,7 +6,7 @@ fn main() { init_logger(); let ctx = Context::new(current_exe().unwrap()); let device = ctx.create_device("cpu"); - let add = device.create_dyn_callable::<(DynExpr, DynExpr), DynExpr>(Box::new( + let add = device.create_dyn_callable:: DynExpr>(Box::new( |a: DynExpr, b: DynExpr| -> DynExpr { if let Some(a) = a.downcast::() { let b = b.downcast::().unwrap(); @@ -25,7 +25,7 @@ fn main() { let w = device.create_buffer::(1024); x.view(..).fill_fn(|i| i as f32); y.view(..).fill_fn(|i| 1000.0 * i as f32); - let kernel = device.create_kernel::<(Buffer,)>(&|buf_z| { + let kernel = device.create_kernel::)>(&|buf_z| { let buf_x = x.var(); let buf_y = y.var(); let tid = dispatch_id().x(); diff --git a/luisa_compute/examples/custom_op.rs b/luisa_compute/examples/custom_op.rs index 74a2e64..8931ab2 100644 --- a/luisa_compute/examples/custom_op.rs +++ b/luisa_compute/examples/custom_op.rs @@ -30,7 +30,7 @@ fn main() { } }); let shader = device - .create_kernel::<(Buffer,)>(&|buf_z: BufferVar| { + .create_kernel::)>(&|buf_z: BufferVar| { // z is pass by arg let buf_x = x.var(); // x and y are captured let buf_y = y.var(); diff --git a/luisa_compute/examples/find_leak.rs b/luisa_compute/examples/find_leak.rs index 0dae34f..c8d573b 100644 --- a/luisa_compute/examples/find_leak.rs +++ b/luisa_compute/examples/find_leak.rs @@ -25,7 +25,7 @@ fn main() { let z = device.create_buffer::(count); x.view(..).fill_fn(|i| i as f32); y.view(..).fill_fn(|i| 1000.0 * i as f32); - let kernel = device.create_kernel::<(Buffer, Buffer, Buffer)>( + let kernel = device.create_kernel::, Buffer, Buffer)>( &|buf_x, buf_y, buf_z| { let tid = dispatch_id().x(); let x = buf_x.read(tid); @@ -47,7 +47,7 @@ fn main() { let z = device.create_buffer::(count); x.view(..).fill_fn(|i| i as f32); y.view(..).fill_fn(|i| 1000.0 * i as f32); - let kernel = device.create_kernel::<(Buffer,)>(&|buf_z| { + let kernel = device.create_kernel::)>(&|buf_z| { let buf_x = x.var(); let buf_y = y.var(); let tid = dispatch_id().x(); diff --git a/luisa_compute/examples/fluid.rs b/luisa_compute/examples/fluid.rs index 5aaacfc..6901b2a 100644 --- a/luisa_compute/examples/fluid.rs +++ b/luisa_compute/examples/fluid.rs @@ -114,7 +114,7 @@ fn main() { }; let advect = device - .create_kernel_async::<(Buffer, Buffer, Buffer, Buffer)>( + .create_kernel_async::, Buffer, Buffer, Buffer)>( &|u0, u1, rho0, rho1| { let coord = dispatch_id().xy(); let u = u0.read(index(coord)); @@ -129,7 +129,7 @@ fn main() { }, ); - let divergence = device.create_kernel_async::<(Buffer, Buffer)>(&|u, div| { + let divergence = device.create_kernel_async::, Buffer)>(&|u, div| { let coord = dispatch_id().xy(); if_!(coord.x().cmplt(N_GRID - 1) & coord.y().cmplt(N_GRID - 1), { let dx = (u.read(index(make_uint2(coord.x() + 1, coord.y()))).x() @@ -143,7 +143,7 @@ fn main() { }); let pressure_solve = - device.create_kernel_async::<(Buffer, Buffer, Buffer)>(&|p0, p1, div| { + device.create_kernel_async::, Buffer, Buffer)>(&|p0, p1, div| { let coord = dispatch_id().xy(); let i = coord.x().int(); let j = coord.y().int(); @@ -159,7 +159,7 @@ fn main() { p1.write(ij, err * 0.25f32); }); - let pressure_apply = device.create_kernel_async::<(Buffer, Buffer)>(&|p, u| { + let pressure_apply = device.create_kernel_async::, Buffer)>(&|p, u| { let coord = dispatch_id().xy(); let i = coord.x().int(); let j = coord.y().int(); @@ -181,7 +181,7 @@ fn main() { ); }); - let integrate = device.create_kernel_async::<(Buffer, Buffer)>(&|u, rho| { + let integrate = device.create_kernel_async::, Buffer)>(&|u, rho| { let coord = dispatch_id().xy(); let ij = index(coord); @@ -196,7 +196,7 @@ fn main() { }); let init = - device.create_kernel_async::<(Buffer, Buffer, Float2)>(&|rho, u, dir| { + device.create_kernel_async::, Buffer, Float2)>(&|rho, u, dir| { let coord = dispatch_id().xy(); let i = coord.x().int(); let j = coord.y().int(); @@ -210,7 +210,7 @@ fn main() { }); }); - let init_grid = device.create_kernel_async::<()>(&|| { + let init_grid = device.create_kernel_async::(&|| { let idx = index(dispatch_id().xy()); u0.var().write(idx, make_float2(0.0f32, 0.0f32)); u1.var().write(idx, make_float2(0.0f32, 0.0f32)); @@ -223,13 +223,13 @@ fn main() { div.var().write(idx, 0.0f32); }); - let clear_pressure = device.create_kernel_async::<()>(&|| { + let clear_pressure = device.create_kernel_async::(&|| { let idx = index(dispatch_id().xy()); p0.var().write(idx, 0.0f32); p1.var().write(idx, 0.0f32); }); - let draw_rho = device.create_kernel_async::<()>(&|| { + let draw_rho = device.create_kernel_async::(&|| { let coord = dispatch_id().xy(); let ij = index(coord); let value = rho0.var().read(ij); diff --git a/luisa_compute/examples/mpm.rs b/luisa_compute/examples/mpm.rs index af28938..08b2420 100644 --- a/luisa_compute/examples/mpm.rs +++ b/luisa_compute/examples/mpm.rs @@ -93,14 +93,14 @@ fn main() { p.x() + p.y() * N_GRID as u32 }; - let clear_grid = device.create_kernel_async::<()>(&|| { + let clear_grid = device.create_kernel_async::(&|| { let idx = index(dispatch_id().xy()); grid_v.var().write(idx * 2, 0.0f32); grid_v.var().write(idx * 2 + 1, 0.0f32); grid_m.var().write(idx, 0.0f32); }); - let point_to_grid = device.create_kernel_async::<()>(&|| { + let point_to_grid = device.create_kernel_async::(&|| { let p = dispatch_id().x(); let xp = x.var().read(p) / DX; let base = (xp - 0.5f32).int(); @@ -128,7 +128,7 @@ fn main() { } }); - let simulate_grid = device.create_kernel_async::<()>(&|| { + let simulate_grid = device.create_kernel_async::(&|| { let coord = dispatch_id().xy(); let i = index(coord); let v = var!(Float2); @@ -157,7 +157,7 @@ fn main() { grid_v.var().write(i * 2 + 1, vy); }); - let grid_to_point = device.create_kernel_async::<()>(&|| { + let grid_to_point = device.create_kernel_async::(&|| { let p = dispatch_id().x(); let xp = x.var().read(p) / DX; let base = (xp - 0.5f32).int(); @@ -192,13 +192,13 @@ fn main() { C.var().write(p, new_C); }); - let clear_display = device.create_kernel_async::<()>(&|| { + let clear_display = device.create_kernel_async::(&|| { display.var().write( dispatch_id().xy(), make_float4(0.1f32, 0.2f32, 0.3f32, 1.0f32), ); }); - let draw_particles = device.create_kernel_async::<()>(&|| { + let draw_particles = device.create_kernel_async::(&|| { let p = dispatch_id().x(); for i in -1..=1 { for j in -1..=1 { diff --git a/luisa_compute/examples/path_tracer.rs b/luisa_compute/examples/path_tracer.rs index dae9afc..49199b7 100644 --- a/luisa_compute/examples/path_tracer.rs +++ b/luisa_compute/examples/path_tracer.rs @@ -247,7 +247,7 @@ fn main() { // use create_kernel_async to compile multiple kernels in parallel let path_tracer = device - .create_kernel_async::<(Tex2d, Tex2d, Accel, Uint2)>( + .create_kernel_async::, Tex2d, Accel, Uint2)>( &|image: Tex2dVar, seed_image: Tex2dVar, accel: AccelVar, @@ -265,7 +265,7 @@ fn main() { ]); let lcg = |state: Var| -> Expr { - let lcg = create_static_callable::<(Var,), Expr>(|state:Var|{ + let lcg = create_static_callable::)-> Expr>(|state:Var|{ const LCG_A: u32 = 1664525u32; const LCG_C: u32 = 1013904223u32; *state.get_mut() = LCG_A * *state + LCG_C; @@ -441,7 +441,7 @@ fn main() { }, ) ; - let display = device.create_kernel_async::<(Tex2d, Tex2d)>(&|acc, display| { + let display = device.create_kernel_async::, Tex2d)>(&|acc, display| { set_block_size([16, 16, 1]); let coord = dispatch_id().xy(); let radiance = acc.read(coord); diff --git a/luisa_compute/examples/path_tracer_cutout.rs b/luisa_compute/examples/path_tracer_cutout.rs index aed14ef..0813423 100644 --- a/luisa_compute/examples/path_tracer_cutout.rs +++ b/luisa_compute/examples/path_tracer_cutout.rs @@ -253,7 +253,7 @@ fn main() { // use create_kernel_async to compile multiple kernels in parallel let path_tracer = device - .create_kernel_async::<(Tex2d, Tex2d, Accel, Uint2)>( + .create_kernel_async::, Tex2d, Accel, Uint2)>( &|image: Tex2dVar, seed_image: Tex2dVar, accel: AccelVar, @@ -271,7 +271,7 @@ fn main() { ]); let lcg = |state: Var| -> Expr { - let lcg = create_static_callable::<(Var, ), Expr>(|state: Var| { + let lcg = create_static_callable::)->Expr>(|state: Var| { const LCG_A: u32 = 1664525u32; const LCG_C: u32 = 1013904223u32; *state.get_mut() = LCG_A * *state + LCG_C; @@ -470,7 +470,7 @@ fn main() { }, ) ; - let display = device.create_kernel_async::<(Tex2d, Tex2d)>(&|acc, display| { + let display = device.create_kernel_async::, Tex2d)>(&|acc, display| { set_block_size([16, 16, 1]); let coord = dispatch_id().xy(); let radiance = acc.read(coord); diff --git a/luisa_compute/examples/polymorphism.rs b/luisa_compute/examples/polymorphism.rs index 7933166..2b447c2 100644 --- a/luisa_compute/examples/polymorphism.rs +++ b/luisa_compute/examples/polymorphism.rs @@ -50,7 +50,7 @@ fn main() { poly_area.register((), &circles); poly_area.register((), &squares); let areas = device.create_buffer::(4); - let shader = device.create_kernel::<()>(&|| { + let shader = device.create_kernel::(&|| { let tid = dispatch_id().x(); let tag = tid / 2; let index = tid % 2; diff --git a/luisa_compute/examples/polymorphism_advanced.rs b/luisa_compute/examples/polymorphism_advanced.rs index 3232d2e..d9cbc54 100644 --- a/luisa_compute/examples/polymorphism_advanced.rs +++ b/luisa_compute/examples/polymorphism_advanced.rs @@ -132,7 +132,7 @@ fn main() { ); let poly_shader = builder.build(); let result = device.create_buffer::(100); - let kernel = device.create_kernel::<()>(&|| { + let kernel = device.create_kernel::(&|| { let i = dispatch_id().x(); let x = i.float() / 100.0 * PI; let ctx = ShaderEvalContext { diff --git a/luisa_compute/examples/printer.rs b/luisa_compute/examples/printer.rs index 32d367a..9ef91cc 100644 --- a/luisa_compute/examples/printer.rs +++ b/luisa_compute/examples/printer.rs @@ -18,7 +18,7 @@ fn main() { "cpu" }); let printer = Printer::new(&device, 65536); - let kernel = device.create_kernel::<()>(&|| { + let kernel = device.create_kernel::(&|| { let id = dispatch_id().xy(); if_!(id.x().cmpeq(id.y()), { lc_info!(printer, "id = {:?}", id); diff --git a/luisa_compute/examples/ray_query.rs b/luisa_compute/examples/ray_query.rs index 975f213..9eb4896 100644 --- a/luisa_compute/examples/ray_query.rs +++ b/luisa_compute/examples/ray_query.rs @@ -108,7 +108,7 @@ fn main() { let img_h = 800; let img = device.create_tex2d::(PixelStorage::Byte4, img_w, img_h, 1); let debug_hit_t = device.create_buffer::(4); - let rt_kernel = device.create_kernel::<()>(&|| { + let rt_kernel = device.create_kernel::(&|| { let accel = accel.var(); let px = dispatch_id().xy(); let xy = px.float() / make_float2(img_w as f32, img_h as f32); diff --git a/luisa_compute/examples/raytracing.rs b/luisa_compute/examples/raytracing.rs index 338c0fe..aca84d1 100644 --- a/luisa_compute/examples/raytracing.rs +++ b/luisa_compute/examples/raytracing.rs @@ -43,7 +43,7 @@ fn main() { let img_w = 800; let img_h = 800; let img = device.create_tex2d::(PixelStorage::Byte4, img_w, img_h, 1); - let rt_kernel = device.create_kernel::<()>(&|| { + let rt_kernel = device.create_kernel::(&|| { let accel = accel.var(); let px = dispatch_id().xy(); let xy = px.float() / make_float2(img_w as f32, img_h as f32); diff --git a/luisa_compute/examples/sdf_renderer.rs b/luisa_compute/examples/sdf_renderer.rs index 4b2e9ce..7a706bb 100644 --- a/luisa_compute/examples/sdf_renderer.rs +++ b/luisa_compute/examples/sdf_renderer.rs @@ -33,25 +33,25 @@ fn main() { radius: 1.0, }]); let x = device.create_buffer::(1024); - let shader = device.create_kernel::<(Buffer, Buffer)>(&|buf_x: BufferVar, - spheres: BufferVar< - Sphere, - >| { - let tid = dispatch_id().x(); - let o = make_float3(0.0, 0.0, -2.0); - let d = make_float3(0.0, 0.0, 1.0); - let sphere = spheres.read(0); - let t = var!(f32); - while_!(t.load().cmplt(10.0), { - let p = o + d * t.load(); - let d = (p - sphere.center()).length() - sphere.radius(); - if_!(d.cmplt(0.001), { - break_(); - }); - t.store(t.load() + d); - }); - buf_x.write(tid, t.load()); - }); + let shader = + device.create_kernel::, Buffer)>( + &|buf_x: BufferVar, spheres: BufferVar| { + let tid = dispatch_id().x(); + let o = make_float3(0.0, 0.0, -2.0); + let d = make_float3(0.0, 0.0, 1.0); + let sphere = spheres.read(0); + let t = var!(f32); + while_!(t.load().cmplt(10.0), { + let p = o + d * t.load(); + let d = (p - sphere.center()).length() - sphere.radius(); + if_!(d.cmplt(0.001), { + break_(); + }); + t.store(t.load() + d); + }); + buf_x.write(tid, t.load()); + }, + ); shader.dispatch([1024, 1, 1], &x, &spheres); let mut x_data = vec![f32::default(); 1024]; x.view(..).copy_to(&mut x_data); diff --git a/luisa_compute/examples/shadertoy.rs b/luisa_compute/examples/shadertoy.rs index ff7be00..534927e 100644 --- a/luisa_compute/examples/shadertoy.rs +++ b/luisa_compute/examples/shadertoy.rs @@ -21,15 +21,16 @@ fn main() { "cpu" }); - let palette = device.create_callable::<(Expr,), Expr>(&|d| { + let palette = device.create_callable::) -> Expr>(&|d| { make_float3(0.2, 0.7, 0.9).lerp(make_float3(1.0, 0.0, 1.0), Float3Expr::splat(d)) }); - let rotate = device.create_callable::<(Expr, Expr), Expr>(&|mut p, a| { - let c = a.cos(); - let s = a.sin(); - make_float2(p.dot(make_float2(c, s)), p.dot(make_float2(-s, c))) - }); - let map = device.create_callable::<(Expr, Expr), Expr>(&|mut p, time| { + let rotate = + device.create_callable::, Expr) -> Expr>(&|mut p, a| { + let c = a.cos(); + let s = a.sin(); + make_float2(p.dot(make_float2(c, s)), p.dot(make_float2(-s, c))) + }); + let map = device.create_callable::, Expr) -> Expr>(&|mut p, time| { for i in 0..8 { let t = time * 0.2; let r = rotate.call(p.xz(), t); @@ -40,7 +41,7 @@ fn main() { } Float3Expr::splat(1.0).copysign(p).dot(p) * 0.2 }); - let rm = device.create_callable::<(Expr, Expr, Expr), Expr>( + let rm = device.create_callable::, Expr, Expr)-> Expr>( &|ro, rd, time| { let t = var!(f32, 0.0); let col = var!(Float3); @@ -56,11 +57,11 @@ fn main() { make_float4(col.x(), col.y(), col.z(), 1.0 / (100.0 * *d)) }, ); - let clear_kernel = device.create_kernel::<(Tex2d,)>(&|img| { + let clear_kernel = device.create_kernel::,)>(&|img| { let coord = dispatch_id().xy(); img.write(coord, make_float4(0.3, 0.4, 0.5, 1.0)); }); - let main_kernel = device.create_kernel::<(Tex2d, f32)>(&|img, time| { + let main_kernel = device.create_kernel::, f32)>(&|img, time| { let xy = dispatch_id().xy(); let resolution = dispatch_size().xy(); let uv = (xy.float() - resolution.float() * 0.5) / resolution.x().float(); diff --git a/luisa_compute/examples/vecadd.rs b/luisa_compute/examples/vecadd.rs index c8bcda4..f6d9afd 100644 --- a/luisa_compute/examples/vecadd.rs +++ b/luisa_compute/examples/vecadd.rs @@ -23,7 +23,7 @@ fn main() { let z = device.create_buffer::(1024); x.view(..).fill_fn(|i| i as f32); y.view(..).fill_fn(|i| 1000.0 * i as f32); - let kernel = device.create_kernel::<(Buffer,)>(&|buf_z| { + let kernel = device.create_kernel::)>(&|buf_z| { // z is pass by arg let buf_x = x.var(); // x and y are captured let buf_y = y.var(); diff --git a/luisa_compute/src/lang/mod.rs b/luisa_compute/src/lang/mod.rs index 8e28be8..a22999b 100644 --- a/luisa_compute/src/lang/mod.rs +++ b/luisa_compute/src/lang/mod.rs @@ -2016,7 +2016,7 @@ impl KernelBuilder { entry, kind: ModuleKind::Kernel, pools: r.pools.clone().unwrap(), - flags: ModuleFlags::REQUIRES_AD_TRANSFORM, + flags: ModuleFlags::REQUIRES_REV_AD_TRANSFORM, }; let module = CallableModule { module: ir_module, @@ -2053,7 +2053,7 @@ impl KernelBuilder { entry, kind: ModuleKind::Kernel, pools: r.pools.clone().unwrap(), - flags: ModuleFlags::REQUIRES_AD_TRANSFORM, + flags: ModuleFlags::REQUIRES_REV_AD_TRANSFORM, }; let ir_module = luisa_compute_ir::transform::luisa_compute_ir_transform_auto(ir_module); let module = KernelModule { @@ -2171,10 +2171,10 @@ unsafe impl CallableRet for T { } } -pub trait CallableSignature { +pub trait CallableSignature<'a> { type Callable; type DynCallable; - type Fn: CallableBuildFn + ?Sized; + type Fn: CallableBuildFn; type StaticFn: StaticCallableBuildFn; type DynFn: CallableBuildFn + 'static; type Ret: CallableRet; @@ -2182,16 +2182,16 @@ pub trait CallableSignature { fn create_dyn_callable(device: Device, init_once: bool, f: Self::DynFn) -> Self::DynCallable; } -pub trait KernelSignature { - type Fn: KernelBuildFn + ?Sized; +pub trait KernelSignature<'a> { + type Fn: KernelBuildFn; type Kernel; fn wrap_raw_kernel(kernel: crate::runtime::RawKernel) -> Self::Kernel; } macro_rules! impl_callable_signature { ()=>{ - impl CallableSignature for fn()->R { - type Fn = dyn Fn() ->R; + impl<'a, R: CallableRet +'static> CallableSignature<'a> for fn()->R { + type Fn = &'a dyn Fn() ->R; type DynFn = BoxR>; type StaticFn = fn() -> R; type Callable = CallableR>; @@ -2212,8 +2212,8 @@ macro_rules! impl_callable_signature { } }; ($first:ident $($rest:ident)*) => { - impl CallableSignature for fn($first, $($rest,)*)->R { - type Fn = dyn Fn($first, $($rest),*)->R; + impl<'a, R:CallableRet +'static, $first:CallableParameter +'static, $($rest: CallableParameter +'static),*> CallableSignature<'a> for fn($first, $($rest,)*)->R { + type Fn = &'a dyn Fn($first, $($rest),*)->R; type DynFn = BoxR>; type Callable = CallableR>; type StaticFn = fn($first, $($rest,)*)->R; @@ -2238,8 +2238,8 @@ macro_rules! impl_callable_signature { impl_callable_signature!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); macro_rules! impl_kernel_signature { ()=>{ - impl KernelSignature for fn() { - type Fn = dyn Fn(); + impl<'a> KernelSignature<'a> for fn() { + type Fn = &'a dyn Fn(); type Kernel = Kernel; fn wrap_raw_kernel(kernel: crate::runtime::RawKernel) -> Self::Kernel { Self::Kernel{ @@ -2250,8 +2250,8 @@ macro_rules! impl_kernel_signature { } }; ($first:ident $($rest:ident)*) => { - impl<'a, $first:KernelArg +'static, $($rest: KernelArg +'static),*> KernelSignature for fn($first, $($rest,)*) { - type Fn = dyn Fn($first::Parameter, $($rest::Parameter),*); + impl<'a, $first:KernelArg +'static, $($rest: KernelArg +'static),*> KernelSignature<'a> for fn($first, $($rest,)*) { + type Fn = &'a dyn Fn($first::Parameter, $($rest::Parameter),*); type Kernel = Kernel; fn wrap_raw_kernel(kernel: crate::runtime::RawKernel) -> Self::Kernel { Self::Kernel{ @@ -2267,7 +2267,7 @@ impl_kernel_signature!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); macro_rules! impl_callable_build_for_fn { ()=>{ - impl CallableBuildFn for dyn Fn()->R { + impl CallableBuildFn for &dyn Fn()->R { fn build_callable(&self, _args: Option>, builder: &mut KernelBuilder)->RawCallable { builder.build_callable( |_| { self() @@ -2291,7 +2291,7 @@ macro_rules! impl_callable_build_for_fn { impl StaticCallableBuildFn for fn()->R {} }; ($first:ident $($rest:ident)*) => { - impl CallableBuildFn for dyn Fn($first, $($rest,)*)->R { + impl CallableBuildFn for &dyn Fn($first, $($rest,)*)->R { #[allow(non_snake_case)] fn build_callable(&self, args: Option>, builder: &mut KernelBuilder)->RawCallable { builder.build_callable( |builder| { @@ -2349,7 +2349,7 @@ macro_rules! impl_callable_build_for_fn { impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); macro_rules! impl_kernel_build_for_fn { ()=>{ - impl KernelBuildFn for dyn Fn() { + impl KernelBuildFn for &dyn Fn() { fn build_kernel(&self, builder: &mut KernelBuilder, options:KernelBuildOptions) -> crate::runtime::RawKernel { builder.build_kernel(options, |_| { self() @@ -2358,7 +2358,7 @@ macro_rules! impl_kernel_build_for_fn { } }; ($first:ident $($rest:ident)*) => { - impl<$first:KernelParameter, $($rest: KernelParameter),*> KernelBuildFn for dyn Fn($first, $($rest,)*) { + impl<$first:KernelParameter, $($rest: KernelParameter),*> KernelBuildFn for &dyn Fn($first, $($rest,)*) { #[allow(non_snake_case)] fn build_kernel(&self, builder: &mut KernelBuilder, options:KernelBuildOptions) -> crate::runtime::RawKernel { builder.build_kernel(options, |builder| { @@ -2743,25 +2743,45 @@ pub fn return_() { struct AdContext { started: bool, backward_called: bool, + is_forward_mode: bool, + n_forward_grads: usize, // forward: Option>, } impl AdContext { - fn new() -> Self { + fn new_rev() -> Self { Self { started: false, backward_called: false, - // forward: None, + is_forward_mode: false, + n_forward_grads: 0, + } + } + fn new_fwd(n: usize) -> Self { + Self { + started: false, + backward_called: false, + is_forward_mode: true, + n_forward_grads: n, } } fn reset(&mut self) { - *self = Self::new(); + self.started = false; } } thread_local! { - static AD_CONTEXT:RefCell = RefCell::new(AdContext::new()); + static AD_CONTEXT:RefCell = RefCell::new(AdContext::new_rev()); } pub fn requires_grad(var: impl ExprProxy) { + AD_CONTEXT.with(|c| { + let c = c.borrow(); + assert!(c.started, "autodiff section is not started"); + assert!( + !c.is_forward_mode, + "requires_grad() is called in forward mode" + ); + assert!(!c.backward_called, "backward is already called"); + }); __current_scope(|b| { b.call(Func::RequiresGradient, &[var.node()], Type::void()); }); @@ -2788,6 +2808,7 @@ pub fn backward_with_grad(out: T, grad: T) { AD_CONTEXT.with(|c| { let mut c = c.borrow_mut(); assert!(c.started, "autodiff section is not started"); + assert!(!c.is_forward_mode, "backward() is called in forward mode"); assert!(!c.backward_called, "backward is already called"); c.backward_called = true; }); @@ -2799,12 +2820,19 @@ pub fn backward_with_grad(out: T, grad: T) { }); } +/// Gradient of a value in *Reverse mode* AD pub fn gradient(var: T) -> T { + AD_CONTEXT.with(|c| { + let c = c.borrow(); + assert!(c.started, "autodiff section is not started"); + assert!(!c.is_forward_mode, "gradient() is called in forward mode"); + assert!(c.backward_called, "backward is not called"); + }); T::from_node(__current_scope(|b| { b.call(Func::Gradient, &[var.node()], var.node().type_().clone()) })) } - +/// Gradient of a value in *Reverse mode* AD pub fn grad(var: T) -> T { gradient(var) } @@ -2834,6 +2862,73 @@ pub fn detach(v: T) -> T { T::from_node(node) } +/// Start a *Forward mode* AD section that propagates N gradients w.r.t to input variable +pub fn forward_autodiff(n_grads: usize, body: impl Fn()) { + AD_CONTEXT.with(|c| { + let mut c = c.borrow_mut(); + assert!(!c.started, "autodiff section already started"); + *c = AdContext::new_fwd(n_grads); + c.started = true; + }); + RECORDER.with(|r| { + let mut r = r.borrow_mut(); + let pools = r.pools.clone().unwrap(); + let s = &mut r.scopes; + s.push(IrBuilder::new(pools)); + }); + body(); + AD_CONTEXT.with(|c| { + let mut c = c.borrow_mut(); + c.reset(); + }); + let body = __pop_scope(); + __current_scope(|b| { + b.ad_scope(body, true); + }); +} + +/// Propagate N gradients w.r.t to input variable using *Forward mode* AD +pub fn propagate_gradient(v: T, grads: &[T]) { + AD_CONTEXT.with(|c| { + let c = c.borrow(); + assert_eq!(grads.len(), c.n_forward_grads); + assert!(c.started, "autodiff section is not started"); + assert!( + c.is_forward_mode, + "propagate_gradient() is called in backward mode" + ); + }); + __current_scope(|b| { + let mut nodes = vec![v.node()]; + nodes.extend(grads.iter().map(|g| g.node())); + b.call(Func::PropagateGrad, &nodes, Type::void()); + }); +} + +pub fn output_gradients(v: T) -> Vec { + let n = AD_CONTEXT.with(|c| { + let c = c.borrow(); + assert!(c.started, "autodiff section is not started"); + assert!( + c.is_forward_mode, + "output_gradients() is called in backward mode" + ); + c.n_forward_grads + }); + __current_scope(|b| { + let mut grads = vec![]; + for i in 0..n { + let idx = b.const_(Const::Int32(i as i32)); + grads.push(T::from_node(b.call( + Func::OutputGrad, + &[v.node(), idx], + Type::void(), + ))); + } + grads + }) +} + pub fn autodiff(body: impl Fn()) { AD_CONTEXT.with(|c| { let mut c = c.borrow_mut(); @@ -2855,8 +2950,7 @@ pub fn autodiff(body: impl Fn()) { }); let body = __pop_scope(); __current_scope(|b| { - let node = Node::new(CArc::new(Instruction::AdScope { body }), Type::void()); - b.append(new_node(b.pools(), node)) + b.ad_scope(body, false); }); } diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index 15f20b6..8adb9f6 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -367,27 +367,30 @@ impl Device { modifications: RwLock::new(HashMap::new()), } } - pub fn create_callable(&self, f: &S::Fn) -> S::Callable { + pub fn create_callable<'a, S: CallableSignature<'a>>(&self, f:S::Fn) -> S::Callable { let mut builder = KernelBuilder::new(Some(self.clone()), false); - let raw_callable = CallableBuildFn::build_callable(f, None, &mut builder); + let raw_callable = CallableBuildFn::build_callable(&f, None, &mut builder); S::wrap_raw_callable(raw_callable) } - pub fn create_dyn_callable(&self, f: S::DynFn) -> S::DynCallable { + pub fn create_dyn_callable<'a, S: CallableSignature<'a>>(&self, f: S::DynFn) -> S::DynCallable { S::create_dyn_callable(self.clone(), false, f) } - pub fn create_dyn_callable_once(&self, f: S::DynFn) -> S::DynCallable { + pub fn create_dyn_callable_once<'a, S: CallableSignature<'a>>( + &self, + f: S::DynFn, + ) -> S::DynCallable { S::create_dyn_callable(self.clone(), true, f) } - pub fn create_kernel(&self, f: &S::Fn) -> S::Kernel { + pub fn create_kernel<'a, S: KernelSignature<'a>>(&self, f: S::Fn) -> S::Kernel { let mut builder = KernelBuilder::new(Some(self.clone()), true); let raw_kernel = - KernelBuildFn::build_kernel(f, &mut builder, KernelBuildOptions::default()); + KernelBuildFn::build_kernel(&f, &mut builder, KernelBuildOptions::default()); S::wrap_raw_kernel(raw_kernel) } - pub fn create_kernel_async(&self, f: &S::Fn) -> S::Kernel { + pub fn create_kernel_async<'a, S: KernelSignature<'a>>(&self, f: S::Fn) -> S::Kernel { let mut builder = KernelBuilder::new(Some(self.clone()), true); let raw_kernel = KernelBuildFn::build_kernel( - f, + &f, &mut builder, KernelBuildOptions { async_compile: true, @@ -396,18 +399,18 @@ impl Device { ); S::wrap_raw_kernel(raw_kernel) } - pub fn create_kernel_with_options( + pub fn create_kernel_with_options<'a, S: KernelSignature<'a>>( &self, - f: &S::Fn, + f: S::Fn, options: KernelBuildOptions, ) -> S::Kernel { let mut builder = KernelBuilder::new(Some(self.clone()), true); - let raw_kernel = KernelBuildFn::build_kernel(f, &mut builder, options); + let raw_kernel = KernelBuildFn::build_kernel(&f, &mut builder, options); S::wrap_raw_kernel(raw_kernel) } } -pub fn create_static_callable(f: S::StaticFn) -> S::Callable { +pub fn create_static_callable<'a, S: CallableSignature<'a>>(f: S::StaticFn) -> S::Callable { let r_backup = RECORDER.with(|r| { let mut r = r.borrow_mut(); std::mem::replace(&mut *r, Recorder::new()) @@ -1152,22 +1155,22 @@ impl RawKernel { } } -pub struct Callable { +pub struct Callable> { #[allow(dead_code)] pub(crate) inner: RawCallable, pub(crate) _marker: std::marker::PhantomData, } -pub(crate) struct DynCallableInner { +pub(crate) struct DynCallableInner> { builder: Box, &mut KernelBuilder) -> Callable>, callables: Vec>, } -pub struct DynCallable { +pub struct DynCallable> { #[allow(dead_code)] pub(crate) inner: RefCell>, pub(crate) device: Device, pub(crate) init_once: bool, } -impl DynCallable { +impl> DynCallable { pub(crate) fn new( device: Device, init_once: bool, @@ -1238,13 +1241,13 @@ pub struct RawCallable { pub(crate) resource_tracker: ResourceTracker, } -pub struct Kernel { +pub struct Kernel> { pub(crate) inner: RawKernel, pub(crate) _marker: std::marker::PhantomData, } -unsafe impl Send for Kernel {} -unsafe impl Sync for Kernel {} -impl Kernel { +unsafe impl> Send for Kernel {} +unsafe impl> Sync for Kernel {} +impl> Kernel { pub fn cache_dir(&self) -> Option { let handle = self.inner.unwrap(); let device = &self.inner.device; @@ -1335,7 +1338,7 @@ impl_call_for_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); macro_rules! impl_dispatch_for_kernel { ($first:ident $($rest:ident)*) => { - impl <$first:KernelArg, $($rest: KernelArg),*> Kernel { + impl <$first:KernelArg+'static, $($rest: KernelArg+'static),*> Kernel { #[allow(non_snake_case)] pub fn dispatch(&self, dispatch_size: [u32; 3], $first:&impl AsKernelArg<$first>, $($rest:&impl AsKernelArg<$rest>),*) { let mut encoder = KernelArgEncoder::new(); diff --git a/luisa_compute/tests/autodiff.rs b/luisa_compute/tests/autodiff.rs index 442caed..98ea5c9 100644 --- a/luisa_compute/tests/autodiff.rs +++ b/luisa_compute/tests/autodiff.rs @@ -91,7 +91,7 @@ fn autodiff_helper Float>( // inputs[i].view(..).copy_from(&tmp); // } println!("init time: {:?}", tic.elapsed()); - let kernel = device.create_kernel_async::<()>(&|| { + let kernel = device.create_kernel_async::(&|| { let input_vars = inputs.iter().map(|input| input.var()).collect::>(); let grad_fd_vars = grad_fd.iter().map(|grad| grad.var()).collect::>(); let grad_ad_vars = grad_ad.iter().map(|grad| grad.var()).collect::>(); @@ -753,7 +753,7 @@ fn autodiff_if_nan() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen::() + 10.0); - let kernel = device.create_kernel::<()>(&|| { + let kernel = device.create_kernel::(&|| { let buf_x = x.var(); let buf_y = y.var(); let buf_dx = dx.var(); @@ -804,7 +804,7 @@ fn autodiff_if_phi() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device.create_kernel::<()>(&|| { + let kernel = device.create_kernel::(&|| { let buf_x = x.var(); let buf_y = y.var(); let buf_dx = dx.var(); @@ -854,7 +854,7 @@ fn autodiff_if_phi2() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device.create_kernel::<()>(&|| { + let kernel = device.create_kernel::(&|| { let buf_x = x.var(); let buf_y = y.var(); let buf_dx = dx.var(); @@ -910,7 +910,7 @@ fn autodiff_if_phi3() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device.create_kernel::<()>(&|| { + let kernel = device.create_kernel::(&|| { let buf_x = x.var(); let buf_y = y.var(); let buf_dx = dx.var(); @@ -971,7 +971,7 @@ fn autodiff_if_phi4() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device.create_kernel::<()>(&|| { + let kernel = device.create_kernel::(&|| { let buf_x = x.var(); let buf_y = y.var(); let buf_dx = dx.var(); @@ -1035,7 +1035,7 @@ fn autodiff_switch() { t.view(..).fill_fn(|_| rng.gen_range(0..3)); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device.create_kernel::<()>(&|| { + let kernel = device.create_kernel::(&|| { let buf_t = t.var(); let buf_x = x.var(); let buf_y = y.var(); @@ -1094,7 +1094,7 @@ fn autodiff_callable() { t.view(..).fill_fn(|_| rng.gen_range(0..3)); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let callable = device.create_callable::<(Var, Var, Expr), ()>(&|vx, vy, t| { + let callable = device.create_callable::, Var, Expr)>(&|vx, vy, t| { let x = *vx; let y = *vy; autodiff(|| { @@ -1110,7 +1110,7 @@ fn autodiff_callable() { *vy.get_mut() = gradient(y); }); }); - let kernel = device.create_kernel::<()>(&|| { + let kernel = device.create_kernel::(&|| { let buf_t = t.var(); let buf_x = x.var(); let buf_y = y.var(); diff --git a/luisa_compute/tests/misc.rs b/luisa_compute/tests/misc.rs index 8b03733..5d57835 100644 --- a/luisa_compute/tests/misc.rs +++ b/luisa_compute/tests/misc.rs @@ -38,10 +38,10 @@ fn event() { let a: Buffer = device.create_buffer_from_slice(&[0]); let b: Buffer = device.create_buffer_from_slice(&[0]); // compute (1 + 3) * (4 + 5) - let add = device.create_kernel::<(Buffer, i32)>(&|buf: BufferVar, v: Expr| { + let add = device.create_kernel::, i32)>(&|buf: BufferVar, v: Expr| { buf.write(0, buf.read(0) + v); }); - let mul = device.create_kernel::<(Buffer, Buffer)>( + let mul = device.create_kernel::, Buffer)>( &|a: BufferVar, b: BufferVar| { a.write(0, a.read(0) * b.read(0)); }, @@ -85,20 +85,20 @@ fn event() { #[test] fn callable() { let device = get_device(); - let write = device.create_callable::<(BufferVar, Expr, Var), ()>( + let write = device.create_callable::, Expr, Var)>( &|buf: BufferVar, i: Expr, v: Var| { buf.write(i, v.load()); v.store(v.load() + 1); }, ); - let add = device.create_callable::<(Expr, Expr), Expr>(&|a, b| a + b); + let add = device.create_callable::, Expr)->Expr>(&|a, b| a + b); let x = device.create_buffer::(1024); let y = device.create_buffer::(1024); let z = device.create_buffer::(1024); let w = device.create_buffer::(1024); x.view(..).fill_fn(|i| i as u32); y.view(..).fill_fn(|i| 1000 * i as u32); - let kernel = device.create_kernel::<(Buffer,)>(&|buf_z| { + let kernel = device.create_kernel::)>(&|buf_z| { let buf_x = x.var(); let buf_y = y.var(); let buf_w = w.var(); @@ -124,7 +124,7 @@ fn vec_cast() { let i: Buffer = device.create_buffer(1024); f.view(..) .fill_fn(|i| Float2::new(i as f32 + 0.5, i as f32 + 1.5)); - let kernel = device.create_kernel_with_options::<()>( + let kernel = device.create_kernel_with_options::( &|| { let f = f.var(); let i = i.var(); @@ -157,7 +157,7 @@ fn bool_op() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device.create_kernel::<()>(&|| { + let kernel = device.create_kernel::(&|| { let tid = dispatch_id().x(); let x = x.var().read(tid); let y = y.var().read(tid); @@ -198,7 +198,7 @@ fn bvec_op() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| Bool2::new(rng.gen(), rng.gen())); y.view(..).fill_fn(|_| Bool2::new(rng.gen(), rng.gen())); - let kernel = device.create_kernel::<()>(&|| { + let kernel = device.create_kernel::(&|| { let tid = dispatch_id().x(); let x = x.var().read(tid); let y = y.var().read(tid); @@ -248,7 +248,7 @@ fn vec_bit_minmax() { x.view(..).fill_fn(|_| Int2::new(rng.gen(), rng.gen())); y.view(..).fill_fn(|_| Int2::new(rng.gen(), rng.gen())); z.view(..).fill_fn(|_| Int2::new(rng.gen(), rng.gen())); - let kernel = device.create_kernel::<()>(&|| { + let kernel = device.create_kernel::(&|| { let tid = dispatch_id().x(); let x = x.var().read(tid); let y = y.var().read(tid); @@ -307,7 +307,7 @@ fn vec_permute() { let v3: Buffer = device.create_buffer(1024); v2.view(..) .fill_fn(|i| Int2::new(i as i32 + 0, i as i32 + 1)); - let kernel = device.create_kernel::<()>(&|| { + let kernel = device.create_kernel::(&|| { let v2 = v2.var(); let v3 = v3.var(); let tid = dispatch_id().x(); @@ -330,7 +330,7 @@ fn if_phi() { let x: Buffer = device.create_buffer(1024); let even: Buffer = device.create_buffer(1024); x.view(..).fill_fn(|i| i as i32); - let kernel = device.create_kernel::<()>(&|| { + let kernel = device.create_kernel::(&|| { let x = x.var(); let even = even.var(); let tid = dispatch_id().x(); @@ -353,7 +353,7 @@ fn switch_phi() { let y: Buffer = device.create_buffer(1024); let z: Buffer = device.create_buffer(1024); x.view(..).fill_fn(|i| i as i32); - let kernel = device.create_kernel::<()>(&|| { + let kernel = device.create_kernel::(&|| { let buf_x = x.var(); let buf_y = y.var(); let buf_z = z.var(); @@ -400,7 +400,7 @@ fn switch_unreachable() { let y: Buffer = device.create_buffer(1024); let z: Buffer = device.create_buffer(1024); x.view(..).fill_fn(|i| i as i32 % 3); - let kernel = device.create_kernel::<()>(&|| { + let kernel = device.create_kernel::(&|| { let buf_x = x.var(); let buf_y = y.var(); let buf_z = z.var(); @@ -442,7 +442,7 @@ fn switch_unreachable() { fn array_read_write() { let device = get_device(); let x: Buffer<[i32; 4]> = device.create_buffer(1024); - let kernel = device.create_kernel::<()>(&|| { + let kernel = device.create_kernel::(&|| { let buf_x = x.var(); let tid = dispatch_id().x(); let arr = local_zeroed::<[i32; 4]>(); @@ -466,7 +466,7 @@ fn array_read_write() { fn array_read_write3() { let device = get_device(); let x: Buffer<[i32; 4]> = device.create_buffer(1024); - let kernel = device.create_kernel::<()>(&|| { + let kernel = device.create_kernel::(&|| { let buf_x = x.var(); let tid = dispatch_id().x(); let arr = local_zeroed::<[i32; 4]>(); @@ -488,7 +488,7 @@ fn array_read_write3() { fn array_read_write4() { let device = get_device(); let x: Buffer<[i32; 4]> = device.create_buffer(1024); - let kernel = device.create_kernel::<()>(&|| { + let kernel = device.create_kernel::(&|| { let buf_x = x.var(); let tid = dispatch_id().x(); let arr = local_zeroed::<[i32; 4]>(); @@ -518,7 +518,7 @@ fn array_read_write2() { let device = get_device(); let x: Buffer<[i32; 4]> = device.create_buffer(1024); let y: Buffer = device.create_buffer(1024); - let kernel = device.create_kernel::<()>(&|| { + let kernel = device.create_kernel::(&|| { let buf_x = x.var(); let buf_y = y.var(); let tid = dispatch_id().x(); @@ -548,7 +548,7 @@ fn array_read_write_vla() { let device = get_device(); let x: Buffer<[i32; 4]> = device.create_buffer(1024); let y: Buffer = device.create_buffer(1024); - let kernel = device.create_kernel::<()>(&|| { + let kernel = device.create_kernel::(&|| { let buf_x = x.var(); let buf_y = y.var(); let tid = dispatch_id().x(); @@ -583,7 +583,7 @@ fn array_read_write_vla() { fn array_read_write_async_compile() { let device = get_device(); let x: Buffer<[i32; 4]> = device.create_buffer(1024); - let kernel = device.create_kernel::<()>(&|| { + let kernel = device.create_kernel::(&|| { let buf_x = x.var(); let tid = dispatch_id().x(); let arr = local_zeroed::<[i32; 4]>(); @@ -610,7 +610,7 @@ fn capture_same_buffer_multiple_view() { let sum = device.create_buffer::(1); x.view(..).fill_fn(|i| i as f32); sum.view(..).fill(0.0); - let shader = device.create_kernel::<()>(&|| { + let shader = device.create_kernel::(&|| { let tid = luisa::dispatch_id().x(); let buf_x_lo = x.view(0..64).var(); let buf_x_hi = x.view(64..).var(); @@ -638,7 +638,7 @@ fn uniform() { let sum = device.create_buffer::(1); x.view(..).fill_fn(|i| i as f32); sum.view(..).fill(0.0); - let shader = device.create_kernel::<(Float3,)>(&|v: Expr| { + let shader = device.create_kernel::(&|v: Expr| { let tid = luisa::dispatch_id().x(); let buf_x_lo = x.view(0..64).var(); let buf_x_hi = x.view(64..).var(); @@ -688,7 +688,7 @@ fn byte_buffer() { let i2 = push!(i32, 0i32); let i3 = push!(f32, 1f32); device - .create_kernel::<()>(&|| { + .create_kernel::(&|| { let buf = buf.var(); let i0 = i0 as u64; let i1 = i1 as u64; @@ -759,7 +759,7 @@ fn bindless_byte_buffer() { let i2 = push!(i32, 0i32); let i3 = push!(f32, 1f32); device - .create_kernel::<(ByteBuffer,)>(&|out:ByteBufferVar| { + .create_kernel::(&|out:ByteBufferVar| { let heap = heap.var(); let buf = heap.byte_address_buffer(0); let i0 = i0 as u64; diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index 5dbfcd9..259b58d 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit 5dbfcd971e969d66dc4c5abd7ba69ab1231c1336 +Subproject commit 259b58d69e67f311bf9de31d19e9c38ea3d4c6ff