diff --git a/Cargo.toml b/Cargo.toml index 32ec0df..265f301 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,9 @@ members = [ "luisa_compute_sys", "luisa_compute_derive_impl", "luisa_compute_derive", + "luisa_compute_track", ] +resolver = "2" # exclude = [ # "luisa_compute_sys/LuisaCompute/src/api/luisa_compute_api_types", -# ] \ No newline at end of file +# ] diff --git a/README.md b/README.md index 5dc5a12..9ef3840 100644 --- a/README.md +++ b/README.md @@ -3,26 +3,33 @@ Rust frontend to LuisaCompute and more! Unified API and embedded DSL for high pe To see the use of `luisa-compute-rs` in a high performance offline rendering system, checkout [our research renderer](https://github.com/shiinamiyuki/akari_render) ## Table of Contents -* [Overview](#overview) - + [Embedded Domain-Specific Language](#embedded-domain-specific-language) - + [Automatic Differentiation](#automatic-differentiation) - + [A CPU backend](#cpu-backend) - + [IR Module for EDSL](#ir-module) - + [Debuggability](#debuggability) -* [Usage](#usage) - + [Building](#building) - + [Variables and Expressions](#variables-and-expressions) - + [Builtin Functions](#builtin-functions) - + [Control Flow](#control-flow) - + [Custom Data Types](#custom-data-types) - + [Polymorphism](#polymorphism) - + [Autodiff](#autodiff) - + [Custom Operators](#custom-operators) - + [Callable](#callable) - + [Kernel](#kernel) -* [Advanced Usage](#advanced-usage) -* [Safety](#safety) -* [Citation](#citation) +- [luisa-compute-rs](#luisa-compute-rs) + - [Table of Contents](#table-of-contents) + - [Example](#example) + - [Vecadd](#vecadd) + - [Overview](#overview) + - [Embedded Domain-Specific Language](#embedded-domain-specific-language) + - [Automatic Differentiation](#automatic-differentiation) + - [CPU Backend](#cpu-backend) + - [IR Module](#ir-module) + - [Debuggability](#debuggability) + - [Usage](#usage) + - [Building](#building) + - [Variables and Expressions](#variables-and-expressions) + - [Builtin Functions](#builtin-functions) + - [Control Flow](#control-flow) + - [`track!` Mcro](#track-mcro) + - [Custom Data Types](#custom-data-types) + - [Polymorphism](#polymorphism) + - [Autodiff](#autodiff) + - [Custom Operators](#custom-operators) + - [Callable](#callable) + - [Kernel](#kernel) + - [Advanced Usage](#advanced-usage) + - [Safety](#safety) + - [API](#api) + - [Backend](#backend) + - [Citation](#citation) ## Example Try `cargo run --release --example path_tracer -- [cpu|cuda|dx|metal]`! @@ -60,7 +67,7 @@ fn main() { let tid = dispatch_id().x(); let x = buf_x.read(tid); let y = buf_y.read(tid); - let vx = var!(f32); // create a local mutable variable + let vx = 0.0f32.var(); // create a local mutable variable *vx.get_mut() += x; buf_z.write(tid, vx.load() + y); }); @@ -125,26 +132,26 @@ For each type, there are two EDSL proxy objects `Expr` and `Var`. `Expr *Note*: Every DSL object in host code **must** be immutable due to Rust unable to overload `operator =`. For example: ```rust // **no good** -let mut v = const_(0.0f32); +let mut v = 0.0f32.expr(); if_!(cond, { v += 1.0; }); // also **not good** -let v = Cell::new(const_(0.0f32)); +let v = Cell::new(0.0f32.expr()); if_!(cond, { v.set(v.get() + 1.0); }); // **good** -let v = var!(f32); +let v = 0.0f32.var(); if_!(cond, { *v.get_mut() += 1.0; }); ``` *Note*: You should not store the referene obtained by `v.get_mut()` for repeated use, as the assigned value is only updated when `v.get_mut()` is dropped. For example,: ```rust -let v = var!(f32); +let v = 0.0f32.var(); let bad = v.get_mut(); *bad = 1.0; let u = *v; @@ -152,15 +159,7 @@ drop(bad); cpu_dbg!(u); // prints 0.0 cpu_dbg!(*v); // prints now 1.0 ``` -All operations except load/store should be performed on `Expr`. `Var` can only be used to load/store values. While `Expr` and `Var` are sufficent in most cases, it cannot be placed in an `impl` block. To do so, the exact name of these proxies are needed. -```rust -Expr == Bool, Var == BoolVar -Expr == Float32, Var == Float32Var -Expr == Int32, Var == Int32Var -Expr == UInt32, Var == UInt32Var -Expr == Int64, Var == Int64Var -Expr == UInt64, Var == UInt64Var -``` +All operations except load/store should be performed on `Expr`. `Var` can only be used to load/store values. As in the C++ EDSL, we additionally supports the following vector/matrix types. Their proxy types are `XXXExpr` and `XXXVar`: @@ -168,9 +167,9 @@ As in the C++ EDSL, we additionally supports the following vector/matrix types. Bool2 // bool2 in C++ Bool3 // bool3 in C++ Bool4 // bool4 in C++ -Vec2 // float2 in C++ -Vec3 // float3 in C++ -Vec4 // float4 in C++ +Float2 // float2 in C++ +Float3 // float3 in C++ +Float4 // float4 in C++ Int2 // int2 in C++ Int3 // int3 in C++ Int4 // int4 in C++ @@ -181,20 +180,19 @@ Mat2 // float2x2 in C++ Mat3 // float3x3 in C++ Mat4 // float4x4 in C++ ``` -Array types `[T;N]` are also supported and their proxy types are `ArrayExpr` and `ArrayVar`. Call `arr.read(i)` and `arr.write(i, value)` on `ArrayVar` for element access. `ArrayExpr` can be stored to and loaded from `ArrayVar`. The limitation is however the array length must be determined during host compile time. If runtime length is required, use `VLArrayVar`. `VLArrayVar::zero(length: usize` would create a zero initialized array. Similarly you can use `read` and `write` methods as well. To query the length of a `VLArrayVar` in host, use ``VLArrayVar::static_len()->usize`. To query the length in kernel, use ``VLArrayVar::len()->Expr` +Array types `[T;N]` are also supported and their proxy types are `ArrayExpr` and `ArrayVar`. Call `arr.read(i)` and `arr.write(i, value)` on `ArrayVar` for element access. `ArrayExpr` can be stored to and loaded from `ArrayVar`. The limitation is however the array length must be determined during host compile time. If runtime length is required, use `VLArrayVar`. `VLArrayVar::zero(length: usize)` would create a zero initialized array. Similarly you can use `read` and `write` methods as well. To query the length of a `VLArrayVar` in host, use `VLArrayVar::static_len()->usize`. To query the length in kernel, use `VLArrayVar::len()->Expr` Most operators are already overloaded with the only exception is comparision. We cannot overload comparision operators as `PartialOrd` cannot return a DSL type. Instead, use `cmpxx` methods such as `cmpgt, cmpeq`, etc. To cast a primitive/vector into another type, use `v.type()`. For example: ```rust -let iv = make_int2(1,1,1); +let iv = Int2::expr(1, 1, 1); let fv = iv.float(); //fv is Expr let bv = fv.bool(); // bv is Expr ``` -To perform a bitwise cast, use the `bitcast` function. `let fv:Expr = bitcast::(const_(0u32));` +To perform a bitwise cast, use the `bitcast` function. `let fv:Expr = bitcast::(0u32);` ### Builtin Functions -We have extentded primitive types with methods similar to their host counterpart: `v.sin(), v.max(u)`, etc. Most methods accepts both a `Expr` or a literal like `0.0`. However, the `select` function is slightly different as it do not accept literals. You need to use `select(cond, f_var, const_(1.0f32))`. - +We have extentded primitive types with methods similar to their host counterpart: `v.sin(), v.max(u)`, etc. Most methods accepts both a `Expr` or a literal like `0.0`. However, the `select` function is slightly different as it does not accept literals. You need to use `select(cond, f_var, 1.0f32.expr())`. ### Control Flow *Note*, you cannot modify outer scope variables inside a control flow block by declaring the variable as `mut`. To modify outer scope variables, use `Var` instead and call *var.get_mut() = value` to store the value back to the outer scope. @@ -223,8 +221,60 @@ let (x,y) = switch::<(Expr, Expr)>(value) .finish(); ``` +### `track!` Mcro + +We also offer a `track!` macro that automatically rewrites control flow primitves and comparison operators. For example (from [`examples/mpm.rs`](luisa_compute/examples/mpm.rs)): + +```rust +track!(|| { + // ... + let vx = select( + coord.x() < BOUND && (vx < 0.0f32) + || coord.x() + BOUND > N_GRID as u32 && (vx > 0.0f32), + 0.0f32.into(), + vx, + ); + let vy = select( + coord.y() < BOUND && (vy < 0.0f32) + || coord.y() + BOUND > N_GRID as u32 && (vy > 0.0f32), + 0.0f32.into(), + vy, + ); + // ... +}) +``` +is equivalent to: +```rust +|| { + // ... + let vx = select( + (coord.x().cmplt(BOUND) & vx.cmplt(0.0f32)) + | (coord.x() + BOUND).cmpgt(N_GRID as u32) & vx.cmpgt(0.0f32), + 0.0f32.into(), + vx, + ); + let vy = select( + (coord.y().cmplt(BOUND) & vy.cmplt(0.0f32)) + | (coord.y() + BOUND).cmpgt(N_GRID as u32) & vy.cmpgt(0.0f32), + 0.0f32.into(), + vy, + ); + // ... +} +``` +Similarily, +```rust +track!(if cond { foo } else if bar { baz } else { qux }) +``` +will be converted to +```rust +if_!(cond, { foo }, { if_!(bar, { baz }, { qux }) }) +``` + +Note that this macro will rewrite `while`, `for _ in x..y`, and `loop` expressions to versions using functions, which will then break the `break` and `continue` expressions. In order to avoid this, it's possible to use the `escape!` macro within a `track!` context to disable rewriting for an expression. + ### Custom Data Types -To add custom data types to the EDSL, simply derive from `luisa::Value` macro. Note that `#[repr(C)]` is required for the struct to be compatible with C ABI. The proxy types are `XXXExpr` and `XXXVar`: +To add custom data types to the EDSL, simply derive from `Value` macro. Note that `#[repr(C)]` is required for the struct to be compatible with C ABI. The proxy types are `XXXExpr` and `XXXVar`: ```rust #[derive(Copy, Clone, Default, Debug, Value)] @@ -234,7 +284,7 @@ pub struct MyVec2 { pub y: f32, } -let v = var!(MyVec2); +let v = MyVec2.var(); let sum = *v.x() + *v.y(); *v.x().get_mut() += 1.0; ``` @@ -282,8 +332,6 @@ autodiff(||{ buf_dv.write(.., dv); buf_dm.write(.., dm); }); - - ``` ### Custom Operators @@ -304,8 +352,8 @@ let my_add = CpuFn::new(|args: &mut MyAddArgs| { let args = MyAddArgsExpr::new(x, y, Float32::zero()); let result = my_add.call(args); - ``` + ### Callable Users can define device-only functions using Callables. Callables have similar type signature to kernels: `CallableRet>`. The difference is that Callables are not dispatchable and can only be called from other Callables or Kernels. Callables can be created using `Device::create_callable`. To invoke a Callable, use `Callable::call(args...)`. Callables accepts arguments such as resources (`BufferVar`, .etc), expressions and references (pass a `Var` to the callable). For example: @@ -317,7 +365,7 @@ let z = add.call(x, y); let pass_by_ref = device.create_callable::)>(&|a| { *a.get_mut() += 1.0; }); -let a = var!(f32, 1.0); +let a = 1.0f32.var(); pass_by_ref.call(a); cpu_dbg!(*a); // prints 2.0 ``` diff --git a/luisa_compute/Cargo.toml b/luisa_compute/Cargo.toml index 3698537..dbf69d2 100644 --- a/luisa_compute/Cargo.toml +++ b/luisa_compute/Cargo.toml @@ -4,7 +4,7 @@ name = "luisa_compute" version = "0.1.1-alpha.1" [dependencies] -base64ct = {version = "1.5.0", features = ["alloc"]} +base64ct = { version = "1.5.0", features = ["alloc"] } bumpalo = "3.12.0" env_logger = "0.10.0" glam = "0.24.0" @@ -14,15 +14,16 @@ lazy_static = "1.4.0" libc = "0.2" libloading = "0.8" log = "0.4" -luisa_compute_api_types = {path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_api_types", version="0.1.1-alpha.1"} -luisa_compute_backend = {path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_backend", version="0.1.1-alpha.1"} -luisa_compute_derive = {path = "../luisa_compute_derive", version="0.1.1-alpha.1"} -luisa_compute_derive_impl = {path = "../luisa_compute_derive_impl", version="0.1.1-alpha.1"} -luisa_compute_ir = {path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_ir", version="0.1.1-alpha.1"} -luisa_compute_sys = {path = "../luisa_compute_sys", version="0.1.1-alpha.1"} +luisa_compute_api_types = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_api_types", version = "0.1.1-alpha.1" } +luisa_compute_backend = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_backend", version = "0.1.1-alpha.1" } +luisa_compute_derive = { path = "../luisa_compute_derive", version = "0.1.1-alpha.1" } +luisa_compute_derive_impl = { path = "../luisa_compute_derive_impl", version = "0.1.1-alpha.1" } +luisa_compute_track = { path = "../luisa_compute_track", version = "0.1.1-alpha.1" } +luisa_compute_ir = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_ir", version = "0.1.1-alpha.1" } +luisa_compute_sys = { path = "../luisa_compute_sys", version = "0.1.1-alpha.1" } parking_lot = "0.12.1" rayon = "1.6.0" -serde = {version = "1.0", features = ["derive"]} +serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" sha2 = "0.10" winit = "0.28.3" diff --git a/luisa_compute/examples/atomic.rs b/luisa_compute/examples/atomic.rs index f686e00..c3c1857 100644 --- a/luisa_compute/examples/atomic.rs +++ b/luisa_compute/examples/atomic.rs @@ -1,7 +1,6 @@ use std::env::current_exe; use luisa::prelude::*; -use luisa::Context; use luisa_compute as luisa; fn main() { @@ -11,12 +10,12 @@ fn main() { let sum = device.create_buffer::(1); x.view(..).fill_fn(|i| i as f32); sum.view(..).fill(0.0); - let shader = device.create_kernel::(&|| { + let shader = device.create_kernel::(&track!(|| { let buf_x = x.var(); let buf_sum = sum.var(); - let tid = luisa::dispatch_id().x(); + let tid = dispatch_id().x(); buf_sum.atomic_fetch_add(0, buf_x.read(tid)); - }); + })); shader.dispatch([x.len() as u32, 1, 1]); let mut sum_data = vec![0.0]; sum.view(..).copy_to(&mut sum_data); diff --git a/luisa_compute/examples/autodiff.rs b/luisa_compute/examples/autodiff.rs index 2cb8f0f..720e858 100644 --- a/luisa_compute/examples/autodiff.rs +++ b/luisa_compute/examples/autodiff.rs @@ -1,6 +1,8 @@ -use std::{env::current_exe, f32::consts::PI}; +use std::env::current_exe; +use std::f32::consts::PI; -use luisa::*; +use luisa::lang::diff::*; +use luisa::prelude::*; use luisa_compute as luisa; fn main() { luisa::init_logger_verbose(); @@ -31,11 +33,13 @@ fn main() { let buf_y = y.var(); let x = buf_x.read(tid); let y = buf_y.read(tid); - let f = |x: Expr, y: Expr| { - if_!(x.cmpgt(y), { x * y }, else, { + let f = track!(|x: Expr, y: Expr| { + if x > y { + x * y + } else { y * x + (x / 32.0 * PI).sin() - }) - }; + } + }); autodiff(|| { requires_grad(x); requires_grad(y); @@ -45,8 +49,8 @@ fn main() { dy_rev.write(tid, gradient(y)); }); forward_autodiff(2, || { - propagate_gradient(x, &[const_(1.0f32), const_(0.0f32)]); - propagate_gradient(y, &[const_(0.0f32), const_(1.0f32)]); + propagate_gradient(x, &[1.0f32.expr(), 0.0f32.expr()]); + propagate_gradient(y, &[0.0f32.expr(), 1.0f32.expr()]); let z = f(x, y); let dx = output_gradients(z)[0]; let dy = output_gradients(z)[1]; diff --git a/luisa_compute/examples/backtrace.rs b/luisa_compute/examples/backtrace.rs index 373f054..d1fabdf 100644 --- a/luisa_compute/examples/backtrace.rs +++ b/luisa_compute/examples/backtrace.rs @@ -1,10 +1,10 @@ -use std::env::{current_exe, self}; +use std::env::{self, current_exe}; +use luisa::prelude::*; use luisa_compute as luisa; fn main() { - use luisa::*; - init_logger(); + luisa::init_logger(); let ctx = Context::new(current_exe().unwrap()); env::set_var("LUISA_DEBUG", "1"); let device = ctx.create_device("cpu"); @@ -20,7 +20,7 @@ fn main() { let tid = dispatch_id().x(); let x = buf_x.read(tid + 123); let y = buf_y.read(tid); - let vx = var!(f32); // create a local mutable variable + let vx = Var::::zeroed(); // create a local mutable variable vx.store(x); buf_z.write(tid, vx.load() + y); }); diff --git a/luisa_compute/examples/bindgroup.rs b/luisa_compute/examples/bindgroup.rs index c330331..a2965fb 100644 --- a/luisa_compute/examples/bindgroup.rs +++ b/luisa_compute/examples/bindgroup.rs @@ -1,6 +1,6 @@ use std::env::current_exe; -use luisa::*; +use luisa::prelude::*; use luisa_compute as luisa; #[derive(BindGroup)] struct MyArgStruct { diff --git a/luisa_compute/examples/bindless.rs b/luisa_compute/examples/bindless.rs index bd4d719..931c354 100644 --- a/luisa_compute/examples/bindless.rs +++ b/luisa_compute/examples/bindless.rs @@ -1,7 +1,8 @@ -use std::{env::current_exe, path::PathBuf}; +use std::env::current_exe; +use std::path::PathBuf; use image::io::Reader as ImageReader; -use luisa::*; +use luisa::prelude::*; use luisa_compute as luisa; fn main() { @@ -12,7 +13,7 @@ fn main() { args[0] ); - init_logger(); + luisa::init_logger(); let ctx = Context::new(current_exe().unwrap()); let device = ctx.create_device(if args.len() == 2 { args[1].as_str() @@ -63,8 +64,8 @@ fn main() { let kernel = device.create_kernel::)>(&|buf_z| { let bindless = bindless.var(); let tid = dispatch_id().x(); - let buf_x = bindless.buffer::(Uint::from(0)); - let buf_y = bindless.buffer::(Uint::from(1)); + let buf_x = bindless.buffer::(0_u32.expr()); + let buf_y = bindless.buffer::(1_u32.expr()); let x = buf_x.read(tid).uint().float(); let y = buf_y.read(tid); buf_z.write(tid, x + y); diff --git a/luisa_compute/examples/callable.rs b/luisa_compute/examples/callable.rs index 1029317..1882950 100644 --- a/luisa_compute/examples/callable.rs +++ b/luisa_compute/examples/callable.rs @@ -1,13 +1,9 @@ -use luisa::derive::*; -use luisa_compute as luisa; -use luisa::Value; use luisa::prelude::*; +use luisa_compute as luisa; use std::env::current_exe; fn main() { - use luisa::*; - init_logger(); - let ctx = Context::new(current_exe().unwrap()); + luisa::init_logger(); let args: Vec = std::env::args().collect(); assert!( args.len() <= 2, @@ -21,7 +17,7 @@ fn main() { } else { "cpu" }); - let add = device.create_callable::, Expr)->Expr>(&|a, b| a + b); + let add = device.create_callable::, Expr) -> Expr>(&|a, b| a + b); let x = device.create_buffer::(1024); let y = device.create_buffer::(1024); let z = device.create_buffer::(1024); diff --git a/luisa_compute/examples/callable_advanced.rs b/luisa_compute/examples/callable_advanced.rs index 74a0e61..a76cc9a 100644 --- a/luisa_compute/examples/callable_advanced.rs +++ b/luisa_compute/examples/callable_advanced.rs @@ -1,10 +1,10 @@ +use luisa::lang::types::dynamic::*; +use luisa::prelude::*; use luisa_compute as luisa; use std::env::current_exe; fn main() { - use luisa::*; - init_logger(); - let ctx = Context::new(current_exe().unwrap()); + luisa::init_logger(); let args: Vec = std::env::args().collect(); assert!( args.len() <= 2, diff --git a/luisa_compute/examples/custom_aggregate.rs b/luisa_compute/examples/custom_aggregate.rs index aae132c..fca5a8f 100644 --- a/luisa_compute/examples/custom_aggregate.rs +++ b/luisa_compute/examples/custom_aggregate.rs @@ -1,8 +1,8 @@ -use luisa::*; +use luisa::prelude::*; use luisa_compute as luisa; #[derive(Aggregate)] pub struct Spectrum { - samples: Vec, + samples: Vec>, } #[derive(Aggregate)] diff --git a/luisa_compute/examples/custom_op.rs b/luisa_compute/examples/custom_op.rs index 8931ab2..501478e 100644 --- a/luisa_compute/examples/custom_op.rs +++ b/luisa_compute/examples/custom_op.rs @@ -1,7 +1,7 @@ use std::env::current_exe; -use luisa::lang::*; -use luisa::Value; +use luisa::lang::debug::CpuFn; +use luisa::prelude::*; use luisa_compute as luisa; #[derive(Clone, Copy, Value, Debug)] #[repr(C)] @@ -12,8 +12,6 @@ pub struct MyAddArgs { } fn main() { - use luisa::*; - let ctx = Context::new(current_exe().unwrap()); let device = ctx.create_device("cpu"); let x = device.create_buffer::(1024); @@ -29,23 +27,21 @@ fn main() { println!("Hello from thread 0!"); } }); - let shader = device - .create_kernel::)>(&|buf_z: BufferVar| { - // z is pass by arg - let buf_x = x.var(); // x and y are captured - let buf_y = y.var(); - let tid = dispatch_id().x(); - let x = buf_x.read(tid); - let y = buf_y.read(tid); - let args = MyAddArgsExpr::new(x, y, Float::zero()); - let result = my_add.call(args); - let _ = my_print.call(tid); - if_!(tid.cmpeq(0), { - cpu_dbg!(args); - }); - buf_z.write(tid, result.result()); - }) - ; + let shader = device.create_kernel::)>(&|buf_z: BufferVar| { + // z is pass by arg + let buf_x = x.var(); // x and y are captured + let buf_y = y.var(); + let tid = dispatch_id().x(); + let x = buf_x.read(tid); + let y = buf_y.read(tid); + let args = MyAddArgsExpr::new(x, y, Expr::::zero()); + let result = my_add.call(args); + let _ = my_print.call(tid); + if_!(tid.cmpeq(0), { + cpu_dbg!(args); + }); + buf_z.write(tid, result.result()); + }); shader.dispatch([1024, 1, 1], &z); let mut z_data = vec![0.0; 1024]; z.view(..).copy_to(&mut z_data); diff --git a/luisa_compute/examples/find_leak.rs b/luisa_compute/examples/find_leak.rs index c8d573b..8f4a58d 100644 --- a/luisa_compute/examples/find_leak.rs +++ b/luisa_compute/examples/find_leak.rs @@ -1,10 +1,10 @@ use std::env::current_exe; +use luisa::prelude::*; use luisa_compute as luisa; fn main() { - use luisa::*; - init_logger(); + luisa::init_logger(); let args: Vec = std::env::args().collect(); assert!( args.len() <= 2, diff --git a/luisa_compute/examples/fluid.rs b/luisa_compute/examples/fluid.rs index 6901b2a..f613df1 100644 --- a/luisa_compute/examples/fluid.rs +++ b/luisa_compute/examples/fluid.rs @@ -1,22 +1,18 @@ #![allow(non_snake_case)] +use std::env::current_exe; use std::mem::swap; -use std::{env::current_exe, time::Instant}; +use std::time::Instant; -use luisa::init_logger; -#[allow(unused_imports)] use luisa::prelude::*; -use luisa::*; use luisa_compute as luisa; -use winit::{ - event::{Event, WindowEvent}, - event_loop::{ControlFlow, EventLoop}, -}; +use winit::event::{Event, WindowEvent}; +use winit::event_loop::{ControlFlow, EventLoop}; const N_GRID: i32 = 512; fn main() { - init_logger(); + luisa::init_logger(); std::env::set_var("WINIT_UNIX_BACKEND", "x11"); let args: Vec = std::env::args().collect(); if args.len() > 2 { @@ -72,17 +68,17 @@ fn main() { let index = |xy: Expr| -> Expr { let p = xy.clamp( - make_uint2(0, 0), - make_uint2(N_GRID as u32 - 1, N_GRID as u32 - 1), + Uint2::expr(0, 0), + Uint2::expr(N_GRID as u32 - 1, N_GRID as u32 - 1), ); p.x() + p.y() * N_GRID as u32 }; - let lookup_float = |f: &BufferVar, x: Int, y: Int| -> Float { - return f.read(index(make_uint2(x.uint(), y.uint()))); + let lookup_float = |f: &BufferVar, x: Expr, y: Expr| -> Expr { + return f.read(index(Uint2::expr(x.uint(), y.uint()))); }; - let sample_float = |f: BufferVar, x: Float, y: Float| -> Float { + let sample_float = |f: BufferVar, x: Expr, y: Expr| -> Expr { let lx = x.floor().int(); let ly = y.floor().int(); @@ -95,22 +91,22 @@ fn main() { return s0.lerp(s1, ty); }; - let lookup_vel = |f: &BufferVar, x: Int, y: Int| -> Float2Expr { - return f.read(index(make_uint2(x.uint(), y.uint()))); + let lookup_vel = |f: &BufferVar, x: Expr, y: Expr| -> Expr { + return f.read(index(Uint2::expr(x.uint(), y.uint()))); }; - let sample_vel = |f: BufferVar, x: Float, y: Float| -> Float2Expr { + let sample_vel = |f: BufferVar, x: Expr, y: Expr| -> Expr { let lx = x.floor().int(); let ly = y.floor().int(); let tx = x - lx.float(); let ty = y - ly.float(); - let s0 = lookup_vel(&f, lx, ly).lerp(lookup_vel(&f, lx + 1, ly), make_float2(tx, tx)); + let s0 = lookup_vel(&f, lx, ly).lerp(lookup_vel(&f, lx + 1, ly), Float2::expr(tx, tx)); let s1 = - lookup_vel(&f, lx, ly + 1).lerp(lookup_vel(&f, lx + 1, ly + 1), make_float2(tx, tx)); + lookup_vel(&f, lx, ly + 1).lerp(lookup_vel(&f, lx + 1, ly + 1), Float2::expr(tx, tx)); - return s0.lerp(s1, make_float2(ty, ty)); + return s0.lerp(s1, Float2::expr(ty, ty)); }; let advect = device @@ -120,7 +116,7 @@ fn main() { let u = u0.read(index(coord)); // trace backward - let mut p = make_float2(coord.x().float(), coord.y().float()); + let mut p = Float2::expr(coord.x().float(), coord.y().float()); p = p - u * dt; // advect @@ -132,10 +128,10 @@ fn main() { let divergence = device.create_kernel_async::, Buffer)>(&|u, div| { let coord = dispatch_id().xy(); if_!(coord.x().cmplt(N_GRID - 1) & coord.y().cmplt(N_GRID - 1), { - let dx = (u.read(index(make_uint2(coord.x() + 1, coord.y()))).x() + let dx = (u.read(index(Uint2::expr(coord.x() + 1, coord.y()))).x() - u.read(index(coord)).x()) * 0.5; - let dy = (u.read(index(make_uint2(coord.x(), coord.y() + 1))).y() + let dy = (u.read(index(Uint2::expr(coord.x(), coord.y() + 1))).y() - u.read(index(coord)).y()) * 0.5; div.write(index(coord), dx + dy); @@ -169,11 +165,11 @@ fn main() { i.cmpgt(0) & i.cmplt(N_GRID - 1) & j.cmpgt(0) & j.cmplt(N_GRID - 1), { // pressure gradient - let f_p = make_float2( - p.read(index(make_uint2(i.uint() + 1, j.uint()))) - - p.read(index(make_uint2(i.uint() - 1, j.uint()))), - p.read(index(make_uint2(i.uint(), j.uint() + 1))) - - p.read(index(make_uint2(i.uint(), j.uint() - 1))), + let f_p = Float2::expr( + p.read(index(Uint2::expr(i.uint() + 1, j.uint()))) + - p.read(index(Uint2::expr(i.uint() - 1, j.uint()))), + p.read(index(Uint2::expr(i.uint(), j.uint() + 1))) + - p.read(index(Uint2::expr(i.uint(), j.uint() - 1))), ) * 0.5f32; u.write(ij, u.read(ij) - f_p); @@ -186,7 +182,7 @@ fn main() { let ij = index(coord); // gravity - let f_g = make_float2(-90.8f32, 0.0f32) * rho.read(ij); + let f_g = Float2::expr(-90.8f32, 0.0f32) * rho.read(ij); // integrate u.write(ij, u.read(ij) + dt * f_g); @@ -201,7 +197,7 @@ fn main() { let i = coord.x().int(); let j = coord.y().int(); let ij = index(coord); - let d = make_float2((i - N_GRID / 2).float(), (j - N_GRID / 2).float()).length(); + let d = Float2::expr((i - N_GRID / 2).float(), (j - N_GRID / 2).float()).length(); let radius = 5.0f32; if_!(d.cmplt(radius), { @@ -212,8 +208,8 @@ fn main() { let init_grid = device.create_kernel_async::(&|| { let idx = index(dispatch_id().xy()); - u0.var().write(idx, make_float2(0.0f32, 0.0f32)); - u1.var().write(idx, make_float2(0.0f32, 0.0f32)); + u0.var().write(idx, Float2::expr(0.0f32, 0.0f32)); + u1.var().write(idx, Float2::expr(0.0f32, 0.0f32)); rho0.var().write(idx, 0.0f32); rho1.var().write(idx, 0.0f32); @@ -234,8 +230,8 @@ fn main() { let ij = index(coord); let value = rho0.var().read(ij); display.var().write( - make_uint2(coord.x(), (N_GRID - 1) as u32 - coord.y()), - make_float4(value, 0.0f32, 0.0f32, 1.0f32), + Uint2::expr(coord.x(), (N_GRID - 1) as u32 - coord.y()), + Float4::expr(value, 0.0f32, 0.0f32, 1.0f32), ); }); diff --git a/luisa_compute/examples/mpm.rs b/luisa_compute/examples/mpm.rs index 08b2420..c6e0d71 100644 --- a/luisa_compute/examples/mpm.rs +++ b/luisa_compute/examples/mpm.rs @@ -1,16 +1,12 @@ #![allow(non_snake_case)] -use std::{env::current_exe, time::Instant}; +use std::env::current_exe; +use std::time::Instant; -use luisa::init_logger; -#[allow(unused_imports)] use luisa::prelude::*; -use luisa::*; use luisa_compute as luisa; use rand::Rng; -use winit::{ - event::{Event, WindowEvent}, - event_loop::{ControlFlow, EventLoop}, -}; +use winit::event::{Event, WindowEvent}; +use winit::event_loop::{ControlFlow, EventLoop}; const N_GRID: usize = 128; const N_STEPS: usize = 50; @@ -26,7 +22,7 @@ const E: f32 = 400.0f32; const RESOLUTION: u32 = 512; fn main() { - init_logger(); + luisa::init_logger(); std::env::set_var("WINIT_UNIX_BACKEND", "x11"); let args: Vec = std::env::args().collect(); if args.len() > 2 { @@ -87,8 +83,8 @@ fn main() { let index = |xy: Expr| -> Expr { let p = xy.clamp( - make_uint2(0, 0), - make_uint2(N_GRID as u32 - 1, N_GRID as u32 - 1), + Uint2::expr(0, 0), + Uint2::expr(N_GRID as u32 - 1, N_GRID as u32 - 1), ); p.x() + p.y() * N_GRID as u32 }; @@ -113,11 +109,11 @@ fn main() { ]; let stress = -4.0f32 * DT * E * P_VOL * (J.var().read(p) - 1.0f32) / (DX * DX); let affine = - Expr::::eye(make_float2(stress, stress)) + P_MASS as f32 * C.var().read(p); + Expr::::eye(Float2::expr(stress, stress)) + P_MASS as f32 * C.var().read(p); let vp = v.var().read(p); for ii in 0..9 { let (i, j) = (ii % 3, ii / 3); - let offset = make_int2(i as i32, j as i32); + let offset = Int2::expr(i as i32, j as i32); let dpos = (offset.float() - fx) * DX; let weight = w[i].x() * w[j].y(); let vadd = weight * (P_MASS * vp + affine * dpos); @@ -128,11 +124,11 @@ fn main() { } }); - let simulate_grid = device.create_kernel_async::(&|| { + let simulate_grid = device.create_kernel_async::(&track!(|| { let coord = dispatch_id().xy(); let i = index(coord); - let v = var!(Float2); - v.store(make_float2( + let v = Var::::zeroed(); + v.store(Float2::expr( grid_v.var().read(i * 2u32), grid_v.var().read(i * 2u32 + 1u32), )); @@ -142,20 +138,20 @@ fn main() { let vx = v.load().x(); let vy = v.load().y() - DT * GRAVITY; let vx = select( - (coord.x().cmplt(BOUND) & vx.cmplt(0.0f32)) - | (coord.x() + BOUND).cmpgt(N_GRID as u32) & vx.cmpgt(0.0f32), + coord.x() < BOUND && (vx < 0.0f32) + || coord.x() + BOUND > N_GRID as u32 && (vx > 0.0f32), 0.0f32.into(), vx, ); let vy = select( - (coord.y().cmplt(BOUND) & vy.cmplt(0.0f32)) - | (coord.y() + BOUND).cmpgt(N_GRID as u32) & vy.cmpgt(0.0f32), + coord.y() < BOUND && (vy < 0.0f32) + || coord.y() + BOUND > N_GRID as u32 && (vy > 0.0f32), 0.0f32.into(), vy, ); grid_v.var().write(i * 2, vx); grid_v.var().write(i * 2 + 1, vy); - }); + })); let grid_to_point = device.create_kernel_async::(&|| { let p = dispatch_id().x(); @@ -168,17 +164,17 @@ fn main() { 0.75f32 - (fx - 1.0f32) * (fx - 1.0f32), 0.5f32 * (fx - 0.5f32) * (fx - 0.5f32), ]; - let new_v = var!(Float2); - let new_C = var!(Mat2); - new_v.store(make_float2(0.0f32, 0.0f32)); - new_C.store(make_float2x2(make_float2(0., 0.), make_float2(0., 0.))); + let new_v = Var::::zeroed(); + let new_C = Var::::zeroed(); + new_v.store(Float2::expr(0.0f32, 0.0f32)); + new_C.store(Mat2::expr(Float2::expr(0., 0.), Float2::expr(0., 0.))); for ii in 0..9 { let (i, j) = (ii % 3, ii / 3); - let offset = make_int2(i as i32, j as i32); + let offset = Int2::expr(i as i32, j as i32); let dpos = (offset.float() - fx) * DX; let weight = w[i].x() * w[j].y(); let idx = index((base + offset).uint()); - let g_v = make_float2( + let g_v = Float2::expr( grid_v.var().read(idx * 2u32), grid_v.var().read(idx * 2u32 + 1u32), ); @@ -195,14 +191,14 @@ fn main() { let clear_display = device.create_kernel_async::(&|| { display.var().write( dispatch_id().xy(), - make_float4(0.1f32, 0.2f32, 0.3f32, 1.0f32), + Float4::expr(0.1f32, 0.2f32, 0.3f32, 1.0f32), ); }); let draw_particles = device.create_kernel_async::(&|| { let p = dispatch_id().x(); for i in -1..=1 { for j in -1..=1 { - let pos = (x.var().read(p) * RESOLUTION as f32).int() + make_int2(i, j); + let pos = (x.var().read(p) * RESOLUTION as f32).int() + Int2::expr(i, j); if_!( pos.x().cmpge(0i32) & pos.x().cmplt(RESOLUTION as i32) @@ -210,8 +206,8 @@ fn main() { & pos.y().cmplt(RESOLUTION as i32), { display.var().write( - make_uint2(pos.x().uint(), RESOLUTION - 1u32 - pos.y().uint()), - make_float4(0.4f32, 0.6f32, 0.6f32, 1.0f32), + Uint2::expr(pos.x().uint(), RESOLUTION - 1u32 - pos.y().uint()), + Float4::expr(0.4f32, 0.6f32, 0.6f32, 1.0f32), ); } ); diff --git a/luisa_compute/examples/path_tracer.rs b/luisa_compute/examples/path_tracer.rs index aa52bcc..16c5c25 100644 --- a/luisa_compute/examples/path_tracer.rs +++ b/luisa_compute/examples/path_tracer.rs @@ -2,16 +2,12 @@ use image::Rgb; use luisa_compute_api_types::StreamTag; use rand::Rng; use std::env::current_exe; -use std::ops::DerefMut; use std::time::Instant; -use winit::event::Event as WinitEvent; -use winit::event::WindowEvent; -use winit::event_loop::{ControlFlow, EventLoop}; +use winit::event::{Event as WinitEvent, WindowEvent}; +use winit::event_loop::EventLoop; -#[allow(unused_imports)] use luisa::prelude::*; -use luisa::rtx::{offset_ray_origin, Accel, AccelVar, Index, Ray}; -use luisa::{Expr, Float3, Value}; +use luisa::rtx::{offset_ray_origin, Accel, AccelBuildRequest, AccelOption, AccelVar, Index, Ray}; use luisa_compute as luisa; #[derive(Value, Clone, Copy)] @@ -175,8 +171,7 @@ f -4 -3 -2 -1"; const SPP_PER_DISPATCH: u32 = 32u32; fn main() { - use luisa::*; - init_logger_verbose(); + luisa::init_logger_verbose(); std::env::set_var("WINIT_UNIX_BACKEND", "x11"); let args: Vec = std::env::args().collect(); @@ -253,7 +248,7 @@ fn main() { accel: AccelVar, resolution: Expr| { set_block_size([16u32, 16u32, 1u32]); - let cbox_materials = const_([ + let cbox_materials = ([ Float3::new(0.725f32, 0.710f32, 0.680f32), // floor Float3::new(0.725f32, 0.710f32, 0.680f32), // ceiling Float3::new(0.725f32, 0.710f32, 0.680f32), // back wall @@ -262,7 +257,7 @@ fn main() { Float3::new(0.725f32, 0.710f32, 0.680f32), // short box Float3::new(0.725f32, 0.710f32, 0.680f32), // tall box Float3::new(0.000f32, 0.000f32, 0.000f32), // light - ]); + ]).expr(); let lcg = |state: Var| -> Expr { let lcg = create_static_callable::)-> Expr>(|state:Var|{ @@ -275,7 +270,7 @@ fn main() { }; let make_ray = |o: Expr, d: Expr, tmin: Expr, tmax: Expr| -> Expr { - struct_!(rtx::Ray { + struct_!(Ray { orig: o.into(), tmin: tmin, dir:d.into(), @@ -285,10 +280,10 @@ fn main() { let generate_ray = |p: Expr| -> Expr { const FOV: f32 = 27.8f32 * std::f32::consts::PI / 180.0f32; - let origin = make_float3(-0.01f32, 0.995f32, 5.0f32); + let origin = Float3::expr(-0.01f32, 0.995f32, 5.0f32); let pixel = origin - + make_float3( + + Float3::expr( p.x() * f32::tan(0.5f32 * FOV), p.y() * f32::tan(0.5f32 * FOV), -1.0f32, @@ -304,9 +299,9 @@ fn main() { let make_onb = |normal: Expr| -> Expr { let binormal = if_!( normal.x().abs().cmpgt(normal.z().abs()), { - make_float3(-normal.y(), normal.x(), 0.0f32) + Float3::expr(-normal.y(), normal.x(), 0.0f32) }, else { - make_float3(0.0f32, -normal.z(), normal.y()) + Float3::expr(0.0f32, -normal.z(), normal.y()) } ); let tangent = binormal.cross(normal).normalize(); @@ -316,39 +311,39 @@ fn main() { let cosine_sample_hemisphere = |u: Expr| { let r = u.x().sqrt(); let phi = 2.0f32 * std::f32::consts::PI * u.y(); - make_float3(r * phi.cos(), r * phi.sin(), (1.0f32 - u.x()).sqrt()) + Float3::expr(r * phi.cos(), r * phi.sin(), (1.0f32 - u.x()).sqrt()) }; let coord = dispatch_id().xy(); let frame_size = resolution.x().min(resolution.y()).float(); - let state = var!(u32); + let state = Var::::zeroed(); state.store(seed_image.read(coord)); let rx = lcg(state); let ry = lcg(state); - let pixel = (coord.float() + make_float2(rx, ry)) / frame_size * 2.0f32 - 1.0f32; + let pixel = (coord.float() + Float2::expr(rx, ry)) / frame_size * 2.0f32 - 1.0f32; - let radiance = var!(Float3); - radiance.store(make_float3(0.0f32, 0.0f32, 0.0f32)); + let radiance = Var::::zeroed(); + radiance.store(Float3::expr(0.0f32, 0.0f32, 0.0f32)); for_range(0..SPP_PER_DISPATCH as u32, |_| { - let init_ray = generate_ray(pixel * make_float2(1.0f32, -1.0f32)); - let ray = var!(Ray); + let init_ray = generate_ray(pixel * Float2::expr(1.0f32, -1.0f32)); + let ray = Var::::zeroed(); ray.store(init_ray); - let beta = var!(Float3); - beta.store(make_float3(1.0f32, 1.0f32, 1.0f32)); - let pdf_bsdf = var!(f32); + let beta = Var::::zeroed(); + beta.store(Float3::expr(1.0f32, 1.0f32, 1.0f32)); + let pdf_bsdf = Var::::zeroed(); pdf_bsdf.store(0.0f32); - let light_position = make_float3(-0.24f32, 1.98f32, 0.16f32); - let light_u = make_float3(-0.24f32, 1.98f32, -0.22f32) - light_position; - let light_v = make_float3(0.23f32, 1.98f32, 0.16f32) - light_position; - let light_emission = make_float3(17.0f32, 12.0f32, 4.0f32); + let light_position = Float3::expr(-0.24f32, 1.98f32, 0.16f32); + let light_u = Float3::expr(-0.24f32, 1.98f32, -0.22f32) - light_position; + let light_v = Float3::expr(0.23f32, 1.98f32, 0.16f32) - light_position; + let light_emission = Float3::expr(17.0f32, 12.0f32, 4.0f32); let light_area = light_u.cross(light_v).length(); let light_normal = light_u.cross(light_v).normalize(); - let depth = var!(u32); + let depth = Var::::zeroed(); while_!(depth.load().cmplt(10u32), { let hit = accel.trace_closest(ray); @@ -414,13 +409,13 @@ fn main() { let onb = make_onb(n); let ux = lcg(state); let uy = lcg(state); - let new_direction = onb.to_world(cosine_sample_hemisphere(make_float2(ux, uy))); + let new_direction = onb.to_world(cosine_sample_hemisphere(Float2::expr(ux, uy))); *ray.get_mut() = make_ray(pp, new_direction, 0.0f32.into(), std::f32::MAX.into()); *beta.get_mut() *= albedo; pdf_bsdf.store(cos_wi * std::f32::consts::FRAC_1_PI); // russian roulette - let l = make_float3(0.212671f32, 0.715160f32, 0.072169f32).dot(*beta); + let l = Float3::expr(0.212671f32, 0.715160f32, 0.072169f32).dot(*beta); if_!(l.cmpeq(0.0f32), { break_(); }); let q = l.max(0.05f32); let r = lcg(state); @@ -432,28 +427,29 @@ fn main() { }); radiance.store(radiance.load() / SPP_PER_DISPATCH as f32); seed_image.write(coord, *state); - if_!(radiance.load().is_nan().any(), { radiance.store(make_float3(0.0f32, 0.0f32, 0.0f32)); }); + if_!(radiance.load().is_nan().any(), { radiance.store(Float3::expr(0.0f32, 0.0f32, 0.0f32)); }); let radiance = radiance.load().clamp(0.0f32, 30.0f32); let old = image.read(dispatch_id().xy()); let spp = old.w(); let radiance = radiance + old.xyz(); - image.write(dispatch_id().xy(), make_float4(radiance.x(), radiance.y(), radiance.z(), spp + 1.0f32)); + image.write(dispatch_id().xy(), Float4::expr(radiance.x(), radiance.y(), radiance.z(), spp + 1.0f32)); }, ) ; - let display = device.create_kernel_async::, Tex2d)>(&|acc, display| { - set_block_size([16, 16, 1]); - let coord = dispatch_id().xy(); - let radiance = acc.read(coord); - let spp = radiance.w(); - let radiance = radiance.xyz() / spp; - - // workaround a rust-analyzer bug - let r = 1.055f32 * radiance.powf(1.0 / 2.4) - 0.055; - - let srgb = Float3Expr::select(radiance.cmplt(0.0031308), radiance * 12.92, r); - display.write(coord, make_float4(srgb.x(), srgb.y(), srgb.z(), 1.0f32)); - }); + let display = + device.create_kernel_async::, Tex2d)>(&|acc, display| { + set_block_size([16, 16, 1]); + let coord = dispatch_id().xy(); + let radiance = acc.read(coord); + let spp = radiance.w(); + let radiance = radiance.xyz() / spp; + + // workaround a rust-analyzer bug + let r = 1.055f32 * radiance.powf(1.0 / 2.4) - 0.055; + + let srgb = Expr::::select(radiance.cmplt(0.0031308), radiance * 12.92, r); + display.write(coord, Float4::expr(srgb.x(), srgb.y(), srgb.z(), 1.0f32)); + }); let img_w = 1024; let img_h = 1024; let acc_img = device.create_tex2d::(PixelStorage::Float4, img_w, img_h, 1); diff --git a/luisa_compute/examples/path_tracer_cutout.rs b/luisa_compute/examples/path_tracer_cutout.rs index 38f0ad5..b4bffa8 100644 --- a/luisa_compute/examples/path_tracer_cutout.rs +++ b/luisa_compute/examples/path_tracer_cutout.rs @@ -2,18 +2,16 @@ use image::Rgb; use luisa_compute_api_types::StreamTag; use rand::Rng; use std::env::current_exe; -use std::ops::{BitAnd, DerefMut}; use std::time::Instant; -use winit::event::Event as WinitEvent; -use winit::event::WindowEvent; -use winit::event_loop::{ControlFlow, EventLoop}; +use winit::event::{Event as WinitEvent, WindowEvent}; +use winit::event_loop::EventLoop; -#[allow(unused_imports)] use luisa::prelude::*; -use luisa::rtx::{offset_ray_origin, Accel, AccelVar, Index, Ray}; -use luisa::{Expr, Float3, Value}; +use luisa::rtx::{ + offset_ray_origin, Accel, AccelBuildRequest, AccelOption, AccelVar, Index, Ray, RayQuery, + TriangleCandidate, +}; use luisa_compute as luisa; -use luisa_compute::rtx::{RayQuery, TriangleCandidate}; #[derive(Value, Clone, Copy)] #[repr(C)] @@ -176,8 +174,7 @@ f -4 -3 -2 -1"; const SPP_PER_DISPATCH: u32 = 1u32; fn main() { - use luisa::*; - init_logger(); + luisa::init_logger(); std::env::set_var("WINIT_UNIX_BACKEND", "x11"); let args: Vec = std::env::args().collect(); @@ -259,7 +256,7 @@ fn main() { accel: AccelVar, resolution: Expr| { set_block_size([16u32, 16u32, 1u32]); - let cbox_materials = const_([ + let cbox_materials = [ Float3::new(0.725f32, 0.710f32, 0.680f32), // floor Float3::new(0.725f32, 0.710f32, 0.680f32), // ceiling Float3::new(0.725f32, 0.710f32, 0.680f32), // back wall @@ -268,7 +265,7 @@ fn main() { Float3::new(0.725f32, 0.710f32, 0.680f32), // short box Float3::new(0.725f32, 0.710f32, 0.680f32), // tall box Float3::new(0.000f32, 0.000f32, 0.000f32), // light - ]); + ].expr(); let lcg = |state: Var| -> Expr { let lcg = create_static_callable::)->Expr>(|state: Var| { @@ -281,7 +278,7 @@ fn main() { }; let make_ray = |o: Expr, d: Expr, tmin: Expr, tmax: Expr| -> Expr { - struct_!(rtx::Ray { + struct_!(Ray { orig: o.into(), tmin: tmin, dir:d.into(), @@ -291,10 +288,10 @@ fn main() { let generate_ray = |p: Expr| -> Expr { const FOV: f32 = 27.8f32 * std::f32::consts::PI / 180.0f32; - let origin = make_float3(-0.01f32, 0.995f32, 5.0f32); + let origin = Float3::expr(-0.01f32, 0.995f32, 5.0f32); let pixel = origin - + make_float3( + + Float3::expr( p.x() * f32::tan(0.5f32 * FOV), p.y() * f32::tan(0.5f32 * FOV), -1.0f32, @@ -310,9 +307,9 @@ fn main() { let make_onb = |normal: Expr| -> Expr { let binormal = if_!( normal.x().abs().cmpgt(normal.z().abs()), { - make_float3(-normal.y(), normal.x(), 0.0f32) + Float3::expr(-normal.y(), normal.x(), 0.0f32) }, else { - make_float3(0.0f32, -normal.z(), normal.y()) + Float3::expr(0.0f32, -normal.z(), normal.y()) } ); let tangent = binormal.cross(normal).normalize(); @@ -322,53 +319,53 @@ fn main() { let cosine_sample_hemisphere = |u: Expr| { let r = u.x().sqrt(); let phi = 2.0f32 * std::f32::consts::PI * u.y(); - make_float3(r * phi.cos(), r * phi.sin(), (1.0f32 - u.x()).sqrt()) + Float3::expr(r * phi.cos(), r * phi.sin(), (1.0f32 - u.x()).sqrt()) }; let coord = dispatch_id().xy(); let frame_size = resolution.x().min(resolution.y()).float(); - let state = var!(u32); + let state = Var::::zeroed(); state.store(seed_image.read(coord)); let rx = lcg(state); let ry = lcg(state); - let pixel = (coord.float() + make_float2(rx, ry)) / frame_size * 2.0f32 - 1.0f32; + let pixel = (coord.float() + Float2::expr(rx, ry)) / frame_size * 2.0f32 - 1.0f32; - let radiance = var!(Float3); - radiance.store(make_float3(0.0f32, 0.0f32, 0.0f32)); + let radiance = Var::::zeroed(); + radiance.store(Float3::expr(0.0f32, 0.0f32, 0.0f32)); for_range(0..SPP_PER_DISPATCH as u32, |_| { - let init_ray = generate_ray(pixel * make_float2(1.0f32, -1.0f32)); - let ray = var!(Ray); + let init_ray = generate_ray(pixel * Float2::expr(1.0f32, -1.0f32)); + let ray = Var::::zeroed(); ray.store(init_ray); - let beta = var!(Float3); - beta.store(make_float3(1.0f32, 1.0f32, 1.0f32)); - let pdf_bsdf = var!(f32); + let beta = Var::::zeroed(); + beta.store(Float3::expr(1.0f32, 1.0f32, 1.0f32)); + let pdf_bsdf = Var::::zeroed(); pdf_bsdf.store(0.0f32); - let light_position = make_float3(-0.24f32, 1.98f32, 0.16f32); - let light_u = make_float3(-0.24f32, 1.98f32, -0.22f32) - light_position; - let light_v = make_float3(0.23f32, 1.98f32, 0.16f32) - light_position; - let light_emission = make_float3(17.0f32, 12.0f32, 4.0f32); + let light_position = Float3::expr(-0.24f32, 1.98f32, 0.16f32); + let light_u = Float3::expr(-0.24f32, 1.98f32, -0.22f32) - light_position; + let light_v = Float3::expr(0.23f32, 1.98f32, 0.16f32) - light_position; + let light_emission = Float3::expr(17.0f32, 12.0f32, 4.0f32); let light_area = light_u.cross(light_v).length(); let light_normal = light_u.cross(light_v).normalize(); let filter = |c: &TriangleCandidate| { - let valid = var!(bool, true); + let valid = true.var(); if_!(c.inst().cmpeq(5u32), { valid.store((c.bary().y() * 6.0f32).fract().cmplt(0.6f32)); }); if_!(c.inst().cmpeq(6u32), { valid.store((c.bary().y() * 5.0f32).fract().cmplt(0.5f32)); }); valid.load() }; - let depth = var!(u32); + let depth = Var::::zeroed(); while_!(depth.load().cmplt(10u32), { // let hit = accel.trace_closest(ray); let hit = accel.query_all(ray, 255, RayQuery { on_triangle_hit: |c: TriangleCandidate| { if_!(filter(&c), { c.commit(); }); }, - on_procedural_hit: |c| {} + on_procedural_hit: |_c| {} }); if_!(hit.miss(), { @@ -386,9 +383,9 @@ fn main() { let p2: Expr = vertex_buffer.read(triangle.z()).into(); let m = accel.instance_transform(hit.inst_id()); let p = p0 * (1.0f32 - hit.bary().x() - hit.bary().y()) + p1 * hit.bary().x() + p2 * hit.bary().y(); - let p = (m * make_float4(p.x(), p.y(), p.z(), 1.0f32)).xyz(); + let p = (m * Float4::expr(p.x(), p.y(), p.z(), 1.0f32)).xyz(); let n = (p1 - p0).cross(p2 - p0); - let n = (m * make_float4(n.x(), n.y(), n.z(), 0.0f32)).xyz().normalize(); + let n = (m * Float4::expr(n.x(), n.y(), n.z(), 0.0f32)).xyz().normalize(); let origin: Expr = ray.load().orig().into(); let direction: Expr = ray.load().dir().into(); @@ -443,13 +440,13 @@ fn main() { let onb = make_onb(n); let ux = lcg(state); let uy = lcg(state); - let new_direction = onb.to_world(cosine_sample_hemisphere(make_float2(ux, uy))); + let new_direction = onb.to_world(cosine_sample_hemisphere(Float2::expr(ux, uy))); *ray.get_mut() = make_ray(pp, new_direction, 0.0f32.into(), std::f32::MAX.into()); *beta.get_mut() *= albedo; pdf_bsdf.store(cos_wi.abs() * std::f32::consts::FRAC_1_PI); // russian roulette - let l = make_float3(0.212671f32, 0.715160f32, 0.072169f32).dot(*beta); + let l = Float3::expr(0.212671f32, 0.715160f32, 0.072169f32).dot(*beta); if_!(l.cmpeq(0.0f32), { break_(); }); let q = l.max(0.05f32); let r = lcg(state); @@ -461,28 +458,29 @@ fn main() { }); radiance.store(radiance.load() / SPP_PER_DISPATCH as f32); seed_image.write(coord, *state); - if_!(radiance.load().is_nan().any(), { radiance.store(make_float3(0.0f32, 0.0f32, 0.0f32)); }); + if_!(radiance.load().is_nan().any(), { radiance.store(Float3::expr(0.0f32, 0.0f32, 0.0f32)); }); let radiance = radiance.load().clamp(0.0f32, 30.0f32); let old = image.read(dispatch_id().xy()); let spp = old.w(); let radiance = radiance + old.xyz(); - image.write(dispatch_id().xy(), make_float4(radiance.x(), radiance.y(), radiance.z(), spp + 1.0f32)); + image.write(dispatch_id().xy(), Float4::expr(radiance.x(), radiance.y(), radiance.z(), spp + 1.0f32)); }, ) ; - let display = device.create_kernel_async::, Tex2d)>(&|acc, display| { - set_block_size([16, 16, 1]); - let coord = dispatch_id().xy(); - let radiance = acc.read(coord); - let spp = radiance.w(); - let radiance = radiance.xyz() / spp; - - // workaround a rust-analyzer bug - let r = 1.055f32 * radiance.powf(1.0 / 2.4) - 0.055; - - let srgb = Float3Expr::select(radiance.cmplt(0.0031308), radiance * 12.92, r); - display.write(coord, make_float4(srgb.x(), srgb.y(), srgb.z(), 1.0f32)); - }); + let display = + device.create_kernel_async::, Tex2d)>(&|acc, display| { + set_block_size([16, 16, 1]); + let coord = dispatch_id().xy(); + let radiance = acc.read(coord); + let spp = radiance.w(); + let radiance = radiance.xyz() / spp; + + // workaround a rust-analyzer bug + let r = 1.055f32 * radiance.powf(1.0 / 2.4) - 0.055; + + let srgb = Expr::::select(radiance.cmplt(0.0031308), radiance * 12.92, r); + display.write(coord, Float4::expr(srgb.x(), srgb.y(), srgb.z(), 1.0f32)); + }); let img_w = 1024; let img_h = 1024; let acc_img = device.create_tex2d::(PixelStorage::Float4, img_w, img_h, 1); diff --git a/luisa_compute/examples/polymorphism.rs b/luisa_compute/examples/polymorphism.rs index 2b447c2..c3d396d 100644 --- a/luisa_compute/examples/polymorphism.rs +++ b/luisa_compute/examples/polymorphism.rs @@ -1,13 +1,12 @@ use std::env::current_exe; use std::f32::consts::PI; +use luisa::lang::poly::*; use luisa::prelude::*; -use luisa::Value; -use luisa::{impl_polymorphic, Float}; use luisa_compute as luisa; trait Area { - fn area(&self) -> Float; + fn area(&self) -> Expr; } #[derive(Value, Clone, Copy)] #[repr(C)] @@ -15,7 +14,7 @@ pub struct Circle { radius: f32, } impl Area for CircleExpr { - fn area(&self) -> Float { + fn area(&self) -> Expr { PI * self.radius() * self.radius() } } @@ -26,13 +25,12 @@ pub struct Square { side: f32, } impl Area for SquareExpr { - fn area(&self) -> Float { + fn area(&self) -> Expr { self.side() * self.side() } } impl_polymorphic!(Area, Square); fn main() { - use luisa::*; let ctx = Context::new(current_exe().unwrap()); let device = ctx.create_device("cpu"); let circles = device.create_buffer::(2); diff --git a/luisa_compute/examples/polymorphism_advanced.rs b/luisa_compute/examples/polymorphism_advanced.rs index d9cbc54..18f7808 100644 --- a/luisa_compute/examples/polymorphism_advanced.rs +++ b/luisa_compute/examples/polymorphism_advanced.rs @@ -1,9 +1,8 @@ use std::env::current_exe; use std::f32::consts::PI; +use luisa::lang::poly::*; use luisa::prelude::*; -use luisa::Value; -use luisa::{impl_polymorphic, lang::*}; use luisa_compute as luisa; #[derive(Clone, Hash, PartialEq, Eq, Debug)] diff --git a/luisa_compute/examples/printer.rs b/luisa_compute/examples/printer.rs index 9ef91cc..752c206 100644 --- a/luisa_compute/examples/printer.rs +++ b/luisa_compute/examples/printer.rs @@ -1,9 +1,12 @@ use std::env::current_exe; +use luisa::prelude::*; +use luisa::printer::*; + use luisa_compute as luisa; + fn main() { - use luisa::*; - init_logger(); + luisa::init_logger(); let args: Vec = std::env::args().collect(); assert!( args.len() <= 2, diff --git a/luisa_compute/examples/ray_query.rs b/luisa_compute/examples/ray_query.rs index 9eb4896..6b08d23 100644 --- a/luisa_compute/examples/ray_query.rs +++ b/luisa_compute/examples/ray_query.rs @@ -1,17 +1,13 @@ use std::env::current_exe; use image::Rgb; -#[allow(unused_imports)] use luisa::prelude::*; -use luisa::rtx::{Aabb, ProceduralCandidate, RayQuery, TriangleCandidate}; -use luisa::Float3; -use luisa::{derive::*, PackedFloat3}; -use luisa_compute as luisa; -use winit::event::Event as WinitEvent; -use winit::{ - event::WindowEvent, - event_loop::{ControlFlow, EventLoop}, +use luisa::rtx::{ + Aabb, AccelBuildRequest, AccelOption, ProceduralCandidate, RayExpr, RayQuery, TriangleCandidate, }; +use luisa_compute as luisa; +use winit::event::{Event as WinitEvent, WindowEvent}; +use winit::event_loop::{ControlFlow, EventLoop}; #[derive(Copy, Clone, Debug, Value)] #[repr(C)] @@ -37,8 +33,7 @@ impl Sphere { } fn main() { - use luisa::*; - init_logger(); + luisa::init_logger(); std::env::set_var("WINIT_UNIX_BACKEND", "x11"); @@ -79,7 +74,7 @@ fn main() { radius: 0.8, }, ]; - let aabb = device.create_buffer_from_slice::(&[ + let aabb = device.create_buffer_from_slice::(&[ spheres[0].aabb(), spheres[1].aabb(), spheres[2].aabb(), @@ -111,21 +106,21 @@ fn main() { let rt_kernel = device.create_kernel::(&|| { let accel = accel.var(); let px = dispatch_id().xy(); - let xy = px.float() / make_float2(img_w as f32, img_h as f32); + let xy = px.float() / Float2::expr(img_w as f32, img_h as f32); let xy = 2.0 * xy - 1.0; - let o = make_float3(0.0, 0.0, -2.0); - let d = make_float3(xy.x(), xy.y(), 0.0) - o; + let o = Float3::expr(0.0, 0.0, -2.0); + let d = Float3::expr(xy.x(), xy.y(), 0.0) - o; let d = d.normalize(); - let ray = rtx::RayExpr::new(o + const_(translate), 1e-3, d, 1e9); + let ray = RayExpr::new(o + translate.expr(), 1e-3, d, 1e9); let hit = accel.query_all( ray, 255, RayQuery { on_triangle_hit: |candidate: TriangleCandidate| { let bary = candidate.bary(); - let uvw = make_float3(1.0 - bary.x() - bary.y(), bary.x(), bary.y()); + let uvw = Float3::expr(1.0 - bary.x() - bary.y(), bary.x(), bary.y()); let t = candidate.committed_ray_t(); - if_!(px.cmpeq(make_uint2(400, 400)).all(), { + if_!(px.cmpeq(Uint2::expr(400, 400)).all(), { debug_hit_t.write(0, t); debug_hit_t.write(1, candidate.ray().tmax()); }); @@ -144,14 +139,14 @@ fn main() { let sphere = spheres.var().read(prim); let o = ray.orig().unpack(); let d = ray.dir().unpack(); - let t = var!(f32); + let t = Var::::zeroed(); - for_range(const_(0i32)..const_(100i32), |_| { - let dist = (o + d * t.load() - (sphere.center() + const_(translate))) + for_range(0i32.expr()..100i32.expr(), |_| { + let dist = (o + d * t.load() - (sphere.center() + translate.expr())) .length() - sphere.radius(); if_!(dist.cmplt(0.001), { - if_!(px.cmpeq(make_uint2(400, 400)).all(), { + if_!(px.cmpeq(Uint2::expr(400, 400)).all(), { debug_hit_t.write(2, *t); debug_hit_t.write(3, candidate.ray().tmax()); }); @@ -170,7 +165,7 @@ fn main() { hit.triangle_hit(), { let bary = hit.bary(); - let uvw = make_float3(1.0 - bary.x() - bary.y(), bary.x(), bary.y()); + let uvw = Float3::expr(1.0 - bary.x() - bary.y(), bary.x(), bary.y()); uvw }, else, @@ -184,19 +179,19 @@ fn main() { + ray.dir().unpack() * hit.committed_ray_t() - sphere.center()) .normalize(); - let light_dir = make_float3(1.0, 0.6, -0.2).normalize(); - let light = make_float3(1.0, 1.0, 1.0); - let ambient = make_float3(0.1, 0.1, 0.1); + let light_dir = Float3::expr(1.0, 0.6, -0.2).normalize(); + let light = Float3::expr(1.0, 1.0, 1.0); + let ambient = Float3::expr(0.1, 0.1, 0.1); let diffuse = light * normal.dot(light_dir).max(0.0); let color = ambient + diffuse; color }, else, - { make_float3(0.0, 0.0, 0.0) } + { Float3::expr(0.0, 0.0, 0.0) } ) } ); - img.write(px, make_float4(color.x(), color.y(), color.z(), 1.0)); + img.write(px, Float4::expr(color.x(), color.y(), color.z(), 1.0)); }); let event_loop = EventLoop::new(); let window = winit::window::WindowBuilder::new() diff --git a/luisa_compute/examples/raytracing.rs b/luisa_compute/examples/raytracing.rs index aca84d1..8a88303 100644 --- a/luisa_compute/examples/raytracing.rs +++ b/luisa_compute/examples/raytracing.rs @@ -1,17 +1,13 @@ use std::env::current_exe; use image::Rgb; -#[allow(unused_imports)] use luisa::prelude::*; +use luisa::rtx::{AccelBuildRequest, AccelOption, RayExpr}; use luisa_compute as luisa; -use winit::{ - event::{WindowEvent}, - event_loop::{ControlFlow, EventLoop}, -}; -use winit::event::Event as WinitEvent; +use winit::event::{Event as WinitEvent, WindowEvent}; +use winit::event_loop::{ControlFlow, EventLoop}; fn main() { - use luisa::*; - init_logger(); + luisa::init_logger(); std::env::set_var("WINIT_UNIX_BACKEND", "x11"); @@ -46,20 +42,20 @@ fn main() { let rt_kernel = device.create_kernel::(&|| { let accel = accel.var(); let px = dispatch_id().xy(); - let xy = px.float() / make_float2(img_w as f32, img_h as f32); + let xy = px.float() / Float2::expr(img_w as f32, img_h as f32); let xy = 2.0 * xy - 1.0; - let o = make_float3(0.0, 0.0, -1.0); - let d = make_float3(xy.x(), xy.y(), 0.0) - o; + let o = Float3::expr(0.0, 0.0, -1.0); + let d = Float3::expr(xy.x(), xy.y(), 0.0) - o; let d = d.normalize(); - let ray = rtx::RayExpr::new(o, 1e-3, d, 1e9); + let ray = RayExpr::new(o, 1e-3, d, 1e9); let hit = accel.trace_closest(ray); let img = img.view(0).var(); let color = select( hit.valid(), - make_float3(hit.u(), hit.v(), 1.0), - make_float3(0.0, 0.0, 0.0), + Float3::expr(hit.u(), hit.v(), 1.0), + Float3::expr(0.0, 0.0, 0.0), ); - img.write(px, make_float4(color.x(), color.y(), color.z(), 1.0)); + img.write(px, Float4::expr(color.x(), color.y(), color.z(), 1.0)); }); let event_loop = EventLoop::new(); let window = winit::window::WindowBuilder::new() diff --git a/luisa_compute/examples/sdf_renderer.rs b/luisa_compute/examples/sdf_renderer.rs index 7a706bb..0eb7f58 100644 --- a/luisa_compute/examples/sdf_renderer.rs +++ b/luisa_compute/examples/sdf_renderer.rs @@ -1,8 +1,6 @@ use std::env::current_exe; -use luisa::math::*; use luisa::prelude::*; -use luisa::Value; use luisa_compute as luisa; #[derive(Copy, Clone, Debug, Value)] @@ -13,7 +11,6 @@ pub struct Sphere { } fn main() { - use luisa::*; let args: Vec = std::env::args().collect(); assert!( args.len() <= 2, @@ -37,10 +34,10 @@ fn main() { device.create_kernel::, Buffer)>( &|buf_x: BufferVar, spheres: BufferVar| { let tid = dispatch_id().x(); - let o = make_float3(0.0, 0.0, -2.0); - let d = make_float3(0.0, 0.0, 1.0); + let o = Float3::expr(0.0, 0.0, -2.0); + let d = Float3::expr(0.0, 0.0, 1.0); let sphere = spheres.read(0); - let t = var!(f32); + let t = Var::::zeroed(); while_!(t.load().cmplt(10.0), { let p = o + d * t.load(); let d = (p - sphere.center()).length() - sphere.radius(); diff --git a/luisa_compute/examples/shadertoy.rs b/luisa_compute/examples/shadertoy.rs index 534927e..3b1ea29 100644 --- a/luisa_compute/examples/shadertoy.rs +++ b/luisa_compute/examples/shadertoy.rs @@ -2,8 +2,7 @@ use luisa::prelude::*; use luisa_compute as luisa; use std::env::current_exe; fn main() { - use luisa::*; - init_logger(); + luisa::init_logger(); std::env::set_var("WINIT_UNIX_BACKEND", "x11"); @@ -22,31 +21,30 @@ fn main() { }); let palette = device.create_callable::) -> Expr>(&|d| { - make_float3(0.2, 0.7, 0.9).lerp(make_float3(1.0, 0.0, 1.0), Float3Expr::splat(d)) + Float3::expr(0.2, 0.7, 0.9).lerp(Float3::expr(1.0, 0.0, 1.0), Expr::::splat(d)) + }); + let rotate = device.create_callable::, Expr) -> Expr>(&|p, a| { + let c = a.cos(); + let s = a.sin(); + Float2::expr(p.dot(Float2::expr(c, s)), p.dot(Float2::expr(-s, c))) }); - let rotate = - device.create_callable::, Expr) -> Expr>(&|mut p, a| { - let c = a.cos(); - let s = a.sin(); - make_float2(p.dot(make_float2(c, s)), p.dot(make_float2(-s, c))) - }); let map = device.create_callable::, Expr) -> Expr>(&|mut p, time| { - for i in 0..8 { + for _i in 0..8 { let t = time * 0.2; let r = rotate.call(p.xz(), t); - p = make_float3(r.x(), r.y(), p.y()).xzy(); + p = Float3::expr(r.x(), r.y(), p.y()).xzy(); let r = rotate.call(p.xy(), t * 1.89); - p = make_float3(r.x(), r.y(), p.z()); - p = make_float3(p.x().abs() - 0.5, p.y(), p.z().abs() - 0.5) + p = Float3::expr(r.x(), r.y(), p.z()); + p = Float3::expr(p.x().abs() - 0.5, p.y(), p.z().abs() - 0.5) } - Float3Expr::splat(1.0).copysign(p).dot(p) * 0.2 + Expr::::splat(1.0).copysign(p).dot(p) * 0.2 }); - let rm = device.create_callable::, Expr, Expr)-> Expr>( + let rm = device.create_callable::, Expr, Expr) -> Expr>( &|ro, rd, time| { - let t = var!(f32, 0.0); - let col = var!(Float3); - let d = var!(f32); - for_range(0i32..64, |i| { + let t = 0.0_f32.var(); + let col = Var::::zeroed(); + let d = Var::::zeroed(); + for_range(0i32..64, |_i| { let p = ro + rd * *t; *d.get_mut() = map.call(p, time) * 0.5; if_!(d.cmplt(0.02) | d.cmpgt(100.0), { break_() }); @@ -54,21 +52,21 @@ fn main() { *t.get_mut() += *d; }); let col = *col; - make_float4(col.x(), col.y(), col.z(), 1.0 / (100.0 * *d)) + Float4::expr(col.x(), col.y(), col.z(), 1.0 / (100.0 * *d)) }, ); - let clear_kernel = device.create_kernel::,)>(&|img| { + let clear_kernel = device.create_kernel::)>(&|img| { let coord = dispatch_id().xy(); - img.write(coord, make_float4(0.3, 0.4, 0.5, 1.0)); + img.write(coord, Float4::expr(0.3, 0.4, 0.5, 1.0)); }); let main_kernel = device.create_kernel::, f32)>(&|img, time| { let xy = dispatch_id().xy(); let resolution = dispatch_size().xy(); let uv = (xy.float() - resolution.float() * 0.5) / resolution.x().float(); - let r = rotate.call(make_float2(0.0, -50.0), time); - let ro = make_float3(r.x(), r.y(), 0.0).xzy(); + let r = rotate.call(Float2::expr(0.0, -50.0), time); + let ro = Float3::expr(r.x(), r.y(), 0.0).xzy(); let cf = (-ro).normalize(); - let cs = cf.cross(make_float3(0.0, 10.0, 0.0)).normalize(); + let cs = cf.cross(Float3::expr(0.0, 10.0, 0.0)).normalize(); let cu = cf.cross(cs).normalize(); let uuv = ro + cf * 3.0 + uv.x() * cs + uv.y() * cu; let rd = (uuv - ro).normalize(); @@ -76,7 +74,7 @@ fn main() { let color = col.xyz(); let alpha = col.w(); let old = img.read(xy).xyz(); - let accum = color.lerp(old, Float3Expr::splat(alpha)); - img.write(xy, make_float4(accum.x(), accum.y(), accum.z(), 1.0)); + let accum = color.lerp(old, Expr::::splat(alpha)); + img.write(xy, Float4::expr(accum.x(), accum.y(), accum.z(), 1.0)); }); } diff --git a/luisa_compute/examples/vecadd.rs b/luisa_compute/examples/vecadd.rs index f6d9afd..6e5a4a0 100644 --- a/luisa_compute/examples/vecadd.rs +++ b/luisa_compute/examples/vecadd.rs @@ -1,10 +1,10 @@ use std::env::current_exe; +use luisa::prelude::*; use luisa_compute as luisa; fn main() { - use luisa::*; - init_logger(); + luisa::init_logger(); let args: Vec = std::env::args().collect(); assert!( args.len() <= 2, @@ -30,7 +30,7 @@ fn main() { let tid = dispatch_id().x(); let x = buf_x.read(tid); let y = buf_y.read(tid); - let vx = var!(f32, 2.0); // create a local mutable variable + let vx = 2.0_f32.var(); // create a local mutable variable *vx.get_mut() += *vx + x; buf_z.write(tid, vx.load() + y); }); diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs new file mode 100644 index 0000000..dae2d7e --- /dev/null +++ b/luisa_compute/src/lang.rs @@ -0,0 +1,449 @@ +use std::any::Any; +use std::cell::{Cell, RefCell}; +use std::fmt::Debug; +use std::sync::atomic::AtomicUsize; +use std::sync::Arc; +use std::{env, unreachable}; + +use crate::internal_prelude::*; + +use bumpalo::Bump; +use indexmap::IndexMap; + +use crate::runtime::WeakDevice; + +pub mod ir { + pub use luisa_compute_ir::context::register_type; + pub use luisa_compute_ir::ir::*; + pub use luisa_compute_ir::*; +} + +pub use ir::NodeRef; +use ir::{ + new_user_node, BasicBlock, Binding, CArc, CallableModuleRef, Const, CpuCustomOp, Func, + Instruction, IrBuilder, ModulePools, Pooled, Type, TypeOf, UserNodeData, +}; + +pub mod control_flow; +pub mod debug; +pub mod diff; +pub mod functions; +pub mod index; +pub mod maybe_expr; +pub mod ops; +pub mod poly; +pub mod swizzle; +pub mod types; + +#[allow(dead_code)] +pub(crate) static KERNEL_ID: AtomicUsize = AtomicUsize::new(0); +// prevent node being shared across kernels +// TODO: replace NodeRef with SafeNodeRef +#[derive(Clone, Copy, Debug)] +pub(crate) struct SafeNodeRef { + #[allow(dead_code)] + pub(crate) node: NodeRef, + #[allow(dead_code)] + pub(crate) kernel_id: usize, +} + +pub trait Aggregate: Sized { + fn to_vec_nodes(&self) -> Vec { + let mut nodes = vec![]; + Self::to_nodes(&self, &mut nodes); + nodes + } + fn from_vec_nodes(nodes: Vec) -> Self { + let mut iter = nodes.into_iter(); + let ret = Self::from_nodes(&mut iter); + assert!(iter.next().is_none()); + ret + } + fn to_nodes(&self, nodes: &mut Vec); + fn from_nodes>(iter: &mut I) -> Self; +} + +impl Aggregate for Vec { + fn to_nodes(&self, nodes: &mut Vec) { + let len_node = new_user_node(__module_pools(), nodes.len()); + nodes.push(len_node); + for item in self { + item.to_nodes(nodes); + } + } + + fn from_nodes>(iter: &mut I) -> Self { + let len_node = iter.next().unwrap(); + let len = len_node.unwrap_user_data::(); + let mut ret = Vec::with_capacity(*len); + for _ in 0..*len { + ret.push(T::from_nodes(iter)); + } + ret + } +} + +impl Aggregate for RefCell { + fn to_nodes(&self, nodes: &mut Vec) { + self.borrow().to_nodes(nodes); + } + + fn from_nodes>(iter: &mut I) -> Self { + RefCell::new(T::from_nodes(iter)) + } +} +impl Aggregate for Cell { + fn to_nodes(&self, nodes: &mut Vec) { + self.get().to_nodes(nodes); + } + + fn from_nodes>(iter: &mut I) -> Self { + Cell::new(T::from_nodes(iter)) + } +} +impl Aggregate for Option { + fn to_nodes(&self, nodes: &mut Vec) { + match self { + Some(x) => { + let node = new_user_node(__module_pools(), 1); + nodes.push(node); + x.to_nodes(nodes); + } + None => { + let node = new_user_node(__module_pools(), 0); + nodes.push(node); + } + } + } + + fn from_nodes>(iter: &mut I) -> Self { + let node = iter.next().unwrap(); + let tag = node.unwrap_user_data::(); + match *tag { + 0 => None, + 1 => Some(T::from_nodes(iter)), + _ => unreachable!(), + } + } +} + +pub trait StructInitiaizable: Value { + type Init: Into; +} + +pub trait ToNode { + fn node(&self) -> NodeRef; +} + +pub trait FromNode: ToNode { + fn from_node(node: NodeRef) -> Self; +} + +fn _store(var: &T1, value: &T2) { + let value_nodes = value.to_vec_nodes(); + let self_nodes = var.to_vec_nodes(); + assert_eq!(value_nodes.len(), self_nodes.len()); + __current_scope(|b| { + for (value_node, self_node) in value_nodes.into_iter().zip(self_nodes.into_iter()) { + b.update(self_node, value_node); + } + }) +} + +#[inline(always)] +pub fn __new_user_node(data: T) -> NodeRef { + new_user_node(__module_pools(), data) +} +macro_rules! impl_aggregate_for_tuple { + ()=>{ + impl Aggregate for () { + fn to_nodes(&self, _: &mut Vec) {} + fn from_nodes>(_: &mut I) -> Self{} + } + }; + ($first:ident $($rest:ident) *) => { + impl<$first:Aggregate, $($rest: Aggregate),*> Aggregate for ($first, $($rest,)*) { + #[allow(non_snake_case)] + fn to_nodes(&self, nodes: &mut Vec) { + let ($first, $($rest,)*) = self; + $first.to_nodes(nodes); + $($rest.to_nodes(nodes);)* + } + #[allow(non_snake_case)] + fn from_nodes>(iter: &mut I) -> Self { + let $first = Aggregate::from_nodes(iter); + $(let $rest = Aggregate::from_nodes(iter);)* + ($first, $($rest,)*) + } + } + impl_aggregate_for_tuple!($($rest)*); + }; + +} +impl_aggregate_for_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); + +pub(crate) struct Recorder { + pub(crate) scopes: Vec, + pub(crate) kernel_id: Option, + pub(crate) lock: bool, + pub(crate) captured_buffer: IndexMap)>, + pub(crate) cpu_custom_ops: IndexMap)>, + pub(crate) callables: IndexMap, + pub(crate) shared: Vec, + pub(crate) device: Option, + pub(crate) block_size: Option<[u32; 3]>, + pub(crate) building_kernel: bool, + pub(crate) pools: Option>, + pub(crate) arena: Bump, + pub(crate) callable_ret_type: Option>, +} + +impl Recorder { + pub(crate) fn reset(&mut self) { + self.scopes.clear(); + self.captured_buffer.clear(); + self.cpu_custom_ops.clear(); + self.callables.clear(); + self.lock = false; + self.device = None; + self.block_size = None; + self.arena.reset(); + self.shared.clear(); + self.kernel_id = None; + self.callable_ret_type = None; + } + pub(crate) fn new() -> Self { + Recorder { + scopes: vec![], + lock: false, + captured_buffer: IndexMap::new(), + cpu_custom_ops: IndexMap::new(), + callables: IndexMap::new(), + shared: vec![], + device: None, + block_size: None, + pools: None, + arena: Bump::new(), + building_kernel: false, + kernel_id: None, + callable_ret_type: None, + } + } +} +thread_local! { + pub(crate) static RECORDER: RefCell = RefCell::new(Recorder::new()); +} + +// Don't call this function directly unless you know what you are doing +pub fn __current_scope R, R>(f: F) -> R { + RECORDER.with(|r| { + let mut r = r.borrow_mut(); + assert!(r.lock, "__current_scope must be called within a kernel"); + let s = &mut r.scopes; + f(s.last_mut().unwrap()) + }) +} + +pub(crate) fn __invoke_callable(callable: &CallableModuleRef, args: &[NodeRef]) -> NodeRef { + RECORDER.with(|r| { + let mut r = r.borrow_mut(); + let id = CArc::as_ptr(&callable.0) as u64; + if let Some(c) = r.callables.get(&id) { + assert_eq!(CArc::as_ptr(&c.0), CArc::as_ptr(&callable.0)); + } else { + r.callables.insert(id, callable.clone()); + } + }); + __current_scope(|b| { + b.call( + Func::Callable(callable.clone()), + args, + callable.0.ret_type.clone(), + ) + }) +} + +pub(crate) fn __check_node_type(a: NodeRef, b: NodeRef) -> bool { + if !ir::context::is_type_equal(a.type_(), b.type_()) { + return false; + } + match (a.get().instruction.as_ref(), b.get().instruction.as_ref()) { + (Instruction::Buffer, Instruction::Buffer) => true, + (Instruction::Texture2D, Instruction::Texture2D) => true, + (Instruction::Texture3D, Instruction::Texture3D) => true, + (Instruction::Bindless, Instruction::Bindless) => true, + (Instruction::Accel, Instruction::Accel) => true, + (Instruction::Uniform, Instruction::Uniform) => true, + (Instruction::Local { .. }, Instruction::Local { .. }) => true, + (Instruction::Argument { by_value: true }, _) => b.get().instruction.has_value(), + (Instruction::Argument { by_value: false }, _) => b.is_lvalue(), + _ => false, + } +} + +pub(crate) fn __check_callable(callable: &CallableModuleRef, args: &[NodeRef]) -> bool { + assert_eq!(callable.0.args.len(), args.len()); + for i in 0..args.len() { + if !__check_node_type(callable.0.args[i], args[i]) { + return false; + } + } + true +} + +// Don't call this function directly unless you know what you are doing +pub fn __pop_scope() -> Pooled { + RECORDER.with(|r| { + let mut r = r.borrow_mut(); + let s = &mut r.scopes; + s.pop().unwrap().finish() + }) +} + +pub fn __module_pools() -> &'static CArc { + RECORDER.with(|r| { + let r = r.borrow(); + assert!(r.lock, "__module_pools must be called within a kernel"); + let pool = r.pools.as_ref().unwrap(); + unsafe { std::mem::transmute(pool) } + }) +} +// pub fn __load(node: NodeRef) -> Expr { +// __current_scope(|b| { +// let node = b.load(node); +// Expr::::from_node(node) +// }) +// } +// pub fn __store(var:NodeRef, value:NodeRef) { +// let inst = &var.get().instruction; +// } + +pub fn __extract(node: NodeRef, index: usize) -> NodeRef { + let inst = &node.get().instruction; + __current_scope(|b| { + let i = b.const_(Const::Int32(index as i32)); + let op = match inst.as_ref() { + Instruction::Local { .. } => Func::GetElementPtr, + Instruction::Argument { by_value } => { + if *by_value { + Func::ExtractElement + } else { + Func::GetElementPtr + } + } + Instruction::Call(f, args) => match f { + Func::AtomicRef => { + let mut indices = args.to_vec(); + indices.push(i); + return b.call(Func::AtomicRef, &indices, ::type_()); + } + _ => Func::ExtractElement, + }, + _ => Func::ExtractElement, + }; + let node = b.call(op, &[node, i], ::type_()); + node + }) +} + +pub fn __insert(node: NodeRef, index: usize, value: NodeRef) -> NodeRef { + let inst = &node.get().instruction; + __current_scope(|b| { + let i = b.const_(Const::Int32(index as i32)); + let op = match inst.as_ref() { + Instruction::Local { .. } => panic!("Can't insert into local variable"), + _ => Func::InsertElement, + }; + let node = b.call(op, &[node, value, i], ::type_()); + node + }) +} + +pub fn __compose(nodes: &[NodeRef]) -> NodeRef { + let ty = ::type_(); + match ty.as_ref() { + Type::Struct(st) => { + assert_eq!(st.fields.as_ref().len(), nodes.len()); + __current_scope(|b| b.call(Func::Struct, nodes, ::type_())) + } + Type::Primitive(_) => panic!("Can't compose primitive type"), + Type::Vector(vt) => { + let length = vt.length; + let func = match length { + 2 => Func::Vec2, + 3 => Func::Vec3, + 4 => Func::Vec4, + _ => panic!("Can't compose vector with length {}", length), + }; + __current_scope(|b| b.call(func, nodes, ::type_())) + } + Type::Matrix(vt) => { + let length = vt.dimension; + let func = match length { + 2 => Func::Mat2, + 3 => Func::Mat3, + 4 => Func::Mat4, + _ => panic!("Can't compose vector with length {}", length), + }; + __current_scope(|b| b.call(func, nodes, ::type_())) + } + _ => todo!(), + } +} +#[macro_export] +macro_rules! struct_ { + ($t:ty { $($it:ident : $value:expr), * $(,)?}) =>{ + { + type Init = <$t as $crate::lang::StructInitiaizable>::Init; + let init = Init { $($it : $value), * }; + type Expr = <$t as $crate::lang::types::Value>::Expr; + let e:Expr = init.into(); + e + } + } +} + +pub const fn packed_size() -> usize { + (std::mem::size_of::() + 3) / 4 +} + +pub fn pack_to(expr: E, buffer: &B, index: impl Into>) +where + E: ExprProxy, + B: IndexWrite, +{ + let index = index.into(); + __current_scope(|b| { + b.call( + Func::Pack, + &[expr.node(), buffer.node(), index.node()], + Type::void(), + ); + }); +} + +pub fn unpack_from( + buffer: &impl IndexWrite, + index: impl Into>, +) -> Expr +where + T: Value, +{ + let index = index.into(); + Expr::::from_node(__current_scope(|b| { + b.call( + Func::Unpack, + &[buffer.node(), index.node()], + ::type_(), + ) + })) +} + +pub(crate) fn need_runtime_check() -> bool { + cfg!(debug_assertions) + || match env::var("LUISA_DEBUG") { + Ok(s) => s == "full" || s == "1", + Err(_) => false, + } + || debug::__env_need_backtrace() +} diff --git a/luisa_compute/src/lang/control_flow.rs b/luisa_compute/src/lang/control_flow.rs new file mode 100644 index 0000000..fa2da05 --- /dev/null +++ b/luisa_compute/src/lang/control_flow.rs @@ -0,0 +1,449 @@ +use std::ffi::CString; + +use crate::internal_prelude::*; +use ir::SwitchCase; + +/** + * If you want rustfmt to format your code, use if_!(cond, { .. }, { .. }) or if_!(cond, { .. }, else, {...}) + * instead of if_!(cond, { .. }, else {...}). + * + */ +#[macro_export] +macro_rules! if_ { + ($cond:expr, $then:block, else $else_:block) => { + <_ as $crate::lang::maybe_expr::BoolIfElseMaybeExpr<_>>::if_then_else( + $cond, + || $then, + || $else_, + ) + }; + ($cond:expr, $then:block, else, $else_:block) => { + <_ as $crate::lang::maybe_expr::BoolIfElseMaybeExpr<_>>::if_then_else( + $cond, + || $then, + || $else_, + ) + }; + ($cond:expr, $then:block, $else_:block) => { + <_ as $crate::lang::maybe_expr::BoolIfElseMaybeExpr<_>>::if_then_else( + $cond, + || $then, + || $else_, + ) + }; + ($cond:expr, $then:block) => { + <_ as $crate::lang::maybe_expr::BoolIfElseMaybeExpr<_>>::if_then_else( + $cond, + || $then, + || {}, + ) + }; +} +#[macro_export] +macro_rules! while_ { + ($cond:expr,$body:block) => { + <_ as $crate::lang::maybe_expr::BoolWhileMaybeExpr>::while_loop(|| $cond, || $body) + }; +} +#[macro_export] +macro_rules! loop_ { + ($body:block) => { + $crate::while_!(true.expr(), $body) + }; +} + +#[inline] +pub fn break_() { + __current_scope(|b| { + b.break_(); + }); +} + +#[inline] +pub fn continue_() { + __current_scope(|b| { + b.continue_(); + }); +} + +pub fn return_v(v: T) { + RECORDER.with(|r| { + let mut r = r.borrow_mut(); + if r.callable_ret_type.is_none() { + r.callable_ret_type = Some(v.node().type_().clone()); + } else { + assert!( + luisa_compute_ir::context::is_type_equal( + r.callable_ret_type.as_ref().unwrap(), + v.node().type_() + ), + "return type mismatch" + ); + } + }); + __current_scope(|b| { + b.return_(v.node()); + }); +} +pub fn return_() { + RECORDER.with(|r| { + let mut r = r.borrow_mut(); + if r.callable_ret_type.is_none() { + r.callable_ret_type = Some(Type::void()); + } else { + assert!(luisa_compute_ir::context::is_type_equal( + r.callable_ret_type.as_ref().unwrap(), + &Type::void() + )); + } + }); + __current_scope(|b| { + b.return_(INVALID_REF); + }); +} + +pub fn if_then_else( + cond: Expr, + then: impl FnOnce() -> R, + else_: impl FnOnce() -> R, +) -> R { + let cond = cond.node(); + RECORDER.with(|r| { + let mut r = r.borrow_mut(); + let pools = r.pools.clone().unwrap(); + let s = &mut r.scopes; + s.push(IrBuilder::new(pools)); + }); + let then = then(); + let then_block = RECORDER.with(|r| { + let mut r = r.borrow_mut(); + let pools = r.pools.clone().unwrap(); + let s = &mut r.scopes; + let then_block = s.pop().unwrap().finish(); + s.push(IrBuilder::new(pools)); + then_block + }); + let else_ = else_(); + let else_block = RECORDER.with(|r| { + let mut r = r.borrow_mut(); + let s = &mut r.scopes; + s.pop().unwrap().finish() + }); + let then_nodes = then.to_vec_nodes(); + let else_nodes = else_.to_vec_nodes(); + __current_scope(|b| { + b.if_(cond, then_block, else_block); + }); + assert_eq!(then_nodes.len(), else_nodes.len()); + let phis = __current_scope(|b| { + then_nodes + .iter() + .zip(else_nodes.iter()) + .map(|(then, else_)| { + let incomings = vec![ + PhiIncoming { + value: *then, + block: then_block, + }, + PhiIncoming { + value: *else_, + block: else_block, + }, + ]; + assert_eq!(then.type_(), else_.type_()); + let phi = b.phi(&incomings, then.type_().clone()); + phi + }) + .collect::>() + }); + R::from_vec_nodes(phis) +} + +pub fn select(mask: Expr, a: A, b: A) -> A { + let a_nodes = a.to_vec_nodes(); + let b_nodes = b.to_vec_nodes(); + assert_eq!(a_nodes.len(), b_nodes.len()); + let mut ret = vec![]; + __current_scope(|b| { + for (a_node, b_node) in a_nodes.into_iter().zip(b_nodes.into_iter()) { + assert_eq!(a_node.type_(), b_node.type_()); + assert!(!a_node.is_local(), "cannot select local variables"); + assert!(!b_node.is_local(), "cannot select local variables"); + if a_node.is_user_data() || b_node.is_user_data() { + assert!( + a_node.is_user_data() && b_node.is_user_data(), + "cannot select user data and non-user data" + ); + let a_data = a_node.get_user_data(); + let b_data = b_node.get_user_data(); + if a_data != b_data { + panic!("cannot select different user data"); + } + ret.push(a_node); + } else { + ret.push(b.call( + Func::Select, + &[mask.node(), a_node, b_node], + a_node.type_().clone(), + )); + } + } + }); + A::from_vec_nodes(ret) +} + +pub fn generic_loop( + mut cond: impl FnMut() -> Expr, + mut body: impl FnMut(), + mut update: impl FnMut(), +) { + RECORDER.with(|r| { + let mut r = r.borrow_mut(); + let pools = r.pools.clone().unwrap(); + let s = &mut r.scopes; + s.push(IrBuilder::new(pools)); + }); + let cond_v = cond().node(); + let prepare = RECORDER.with(|r| { + let mut r = r.borrow_mut(); + let pools = r.pools.clone().unwrap(); + let s = &mut r.scopes; + let prepare = s.pop().unwrap().finish(); + s.push(IrBuilder::new(pools)); + prepare + }); + body(); + let body = RECORDER.with(|r| { + let mut r = r.borrow_mut(); + let pools = r.pools.clone().unwrap(); + let s = &mut r.scopes; + let body = s.pop().unwrap().finish(); + s.push(IrBuilder::new(pools)); + body + }); + update(); + let update = RECORDER.with(|r| { + let mut r = r.borrow_mut(); + let s = &mut r.scopes; + s.pop().unwrap().finish() + }); + __current_scope(|b| { + b.generic_loop(prepare, cond_v, body, update); + }); +} + +pub trait ForLoopRange { + type Element: Value; + fn start(&self) -> NodeRef; + fn end(&self) -> NodeRef; + fn end_inclusive(&self) -> bool; +} +macro_rules! impl_range { + ($t:ty) => { + impl ForLoopRange for std::ops::RangeInclusive<$t> { + type Element = $t; + fn start(&self) -> NodeRef { + (*self.start()).expr().node() + } + fn end(&self) -> NodeRef { + (*self.end()).expr().node() + } + fn end_inclusive(&self) -> bool { + true + } + } + impl ForLoopRange for std::ops::RangeInclusive> { + type Element = $t; + fn start(&self) -> NodeRef { + self.start().node() + } + fn end(&self) -> NodeRef { + self.end().node() + } + fn end_inclusive(&self) -> bool { + true + } + } + impl ForLoopRange for std::ops::Range<$t> { + type Element = $t; + fn start(&self) -> NodeRef { + (self.start).expr().node() + } + fn end(&self) -> NodeRef { + (self.end).expr().node() + } + fn end_inclusive(&self) -> bool { + false + } + } + impl ForLoopRange for std::ops::Range> { + type Element = $t; + fn start(&self) -> NodeRef { + self.start.node() + } + fn end(&self) -> NodeRef { + self.end.node() + } + fn end_inclusive(&self) -> bool { + false + } + } + }; +} +impl_range!(i32); +impl_range!(i64); +impl_range!(u32); +impl_range!(u64); + +pub fn for_range(r: R, body: impl Fn(Expr)) { + let start = r.start(); + let end = r.end(); + let inc = |v: NodeRef| { + __current_scope(|b| { + let one = b.const_(Const::One(v.type_().clone())); + b.call(Func::Add, &[v, one], v.type_().clone()) + }) + }; + let i = __current_scope(|b| b.local(start)); + generic_loop( + || { + __current_scope(|b| { + let i = b.call(Func::Load, &[i], i.type_().clone()); + Expr::::from_node(b.call( + if r.end_inclusive() { + Func::Le + } else { + Func::Lt + }, + &[i, end], + ::type_(), + )) + }) + }, + move || { + let i = __current_scope(|b| b.call(Func::Load, &[i], i.type_().clone())); + body(Expr::::from_node(i)); + }, + || { + let i_old = __current_scope(|b| b.call(Func::Load, &[i], i.type_().clone())); + let i_new = inc(i_old); + __current_scope(|b| b.update(i, i_new)); + }, + ) +} + +pub struct SwitchBuilder { + cases: Vec<(i32, Pooled, Vec)>, + default: Option<(Pooled, Vec)>, + value: NodeRef, + _marker: PhantomData, + depth: usize, +} + +pub fn switch(node: Expr) -> SwitchBuilder { + SwitchBuilder::new(node) +} + +impl SwitchBuilder { + pub fn new(node: Expr) -> Self { + SwitchBuilder { + cases: vec![], + default: None, + value: node.node(), + _marker: PhantomData, + depth: RECORDER.with(|r| r.borrow().scopes.len()), + } + } + pub fn case(mut self, value: i32, then: impl Fn() -> R) -> Self { + RECORDER.with(|r| { + let mut r = r.borrow_mut(); + let pools = r.pools.clone().unwrap(); + let s = &mut r.scopes; + assert_eq!(s.len(), self.depth); + s.push(IrBuilder::new(pools)); + }); + let then = then(); + let block = __pop_scope(); + self.cases.push((value, block, then.to_vec_nodes())); + self + } + pub fn default(mut self, then: impl Fn() -> R) -> Self { + RECORDER.with(|r| { + let mut r = r.borrow_mut(); + let pools = r.pools.clone().unwrap(); + let s = &mut r.scopes; + assert_eq!(s.len(), self.depth); + s.push(IrBuilder::new(pools)); + }); + let then = then(); + let block = __pop_scope(); + self.default = Some((block, then.to_vec_nodes())); + self + } + pub fn finish(self) -> R { + RECORDER.with(|r| { + let mut r = r.borrow_mut(); + let s = &mut r.scopes; + assert_eq!(s.len(), self.depth); + }); + let cases = self + .cases + .iter() + .map(|(v, b, _)| SwitchCase { + value: *v, + block: *b, + }) + .collect::>(); + let case_phis = self + .cases + .iter() + .map(|(_, _, nodes)| nodes.clone()) + .collect::>(); + let phi_count = case_phis[0].len(); + let mut default_nodes = vec![]; + let default_block = if self.default.is_none() { + RECORDER.with(|r| { + let mut r = r.borrow_mut(); + let pools = r.pools.clone().unwrap(); + let s = &mut r.scopes; + assert_eq!(s.len(), self.depth); + s.push(IrBuilder::new(pools)); + }); + for i in 0..phi_count { + let msg = CString::new("unreachable code in switch statement!").unwrap(); + let default_node = __current_scope(|b| { + b.call( + Func::Unreachable(CBoxedSlice::from(msg)), + &[], + case_phis[0][i].type_().clone(), + ) + }); + default_nodes.push(default_node); + } + __pop_scope() + } else { + default_nodes = self.default.as_ref().unwrap().1.clone(); + self.default.as_ref().unwrap().0 + }; + __current_scope(|b| { + b.switch(self.value, &cases, default_block); + }); + let mut phis = vec![]; + for i in 0..phi_count { + let mut incomings = vec![]; + for (j, nodes) in case_phis.iter().enumerate() { + incomings.push(PhiIncoming { + value: nodes[i], + block: self.cases[j].1, + }); + } + incomings.push(PhiIncoming { + value: default_nodes[i], + block: default_block, + }); + let phi = __current_scope(|b| b.phi(&incomings, case_phis[0][i].type_().clone())); + phis.push(phi); + } + R::from_vec_nodes(phis) + } +} diff --git a/luisa_compute/src/lang/debug.rs b/luisa_compute/src/lang/debug.rs new file mode 100644 index 0000000..1f2c362 --- /dev/null +++ b/luisa_compute/src/lang/debug.rs @@ -0,0 +1,217 @@ +use ir::CpuCustomOp; +use std::ffi::CString; +use std::fmt::Debug; +use std::sync::Arc; + +use crate::internal_prelude::*; + +pub struct CpuFn { + op: CArc, + _marker: PhantomData, +} + +/* +Interestingly, Box::into_raw(Box) does not give a valid pointer. +*/ +struct ClosureContainer { + f: Arc, +} + +impl CpuFn { + pub fn new(f: F) -> Self { + let f_ptr = Box::into_raw(Box::new(ClosureContainer:: { f: Arc::new(f) })); + let op = CpuCustomOp { + data: f_ptr as *mut u8, + func: _trampoline::, + destructor: _drop::, + arg_type: T::type_(), + }; + Self { + op: CArc::new(op), + _marker: PhantomData, + } + } + pub fn call(&self, arg: impl ExprProxy) -> Expr { + RECORDER.with(|r| { + let mut r = r.borrow_mut(); + assert!(r.lock); + assert_eq!( + r.device + .as_ref() + .unwrap() + .upgrade() + .unwrap() + .inner + .query("device_name") + .unwrap(), + "cpu", + "CpuFn can only be used in cpu backend" + ); + let addr = CArc::as_ptr(&self.op) as u64; + if let Some((_, op)) = r.cpu_custom_ops.get(&addr) { + assert_eq!(CArc::as_ptr(op), CArc::as_ptr(&self.op)); + } else { + let i = r.cpu_custom_ops.len(); + r.cpu_custom_ops.insert(addr, (i, self.op.clone())); + } + }); + Expr::::from_node(__current_scope(|b| { + b.call( + Func::CpuCustomOp(self.op.clone()), + &[arg.node()], + T::type_(), + ) + })) + } +} + +extern "C" fn _trampoline(data: *mut u8, args: *mut u8) { + unsafe { + let container = &*(data as *const ClosureContainer); + let f = &container.f; + let args = &mut *(args as *mut T); + f(args); + } +} + +extern "C" fn _drop(data: *mut u8) { + unsafe { + let _ = Box::from_raw(data as *mut T); + } +} + +#[macro_export] +macro_rules! cpu_dbg { + ($arg:expr) => {{ + $crate::lang::debug::__cpu_dbg($arg, file!(), line!()) + }}; +} +#[macro_export] +macro_rules! lc_dbg { + ($arg:expr) => {{ + $crate::lang::debug::__cpu_dbg($arg, file!(), line!()) + }}; +} +#[macro_export] +macro_rules! lc_unreachable { + () => { + $crate::lang::debug::__unreachable(file!(), line!(), column!()) + }; +} +#[macro_export] +macro_rules! lc_assert { + ($arg:expr) => { + $crate::lang::debug::__assert($arg, stringify!($arg), file!(), line!(), column!()) + }; + ($arg:expr, $msg:expr) => { + $crate::lang::debug::__assert($arg, $msg, file!(), line!(), column!()) + }; +} +pub fn __cpu_dbg(arg: T, file: &'static str, line: u32) +where + T::Value: Debug, +{ + if !is_cpu_backend() { + return; + } + let f = CpuFn::new(move |x: &mut T::Value| { + println!("[{}:{}] {:?}", file, line, x); + }); + let _ = f.call(arg); +} + +pub fn is_cpu_backend() -> bool { + RECORDER.with(|r| { + let r = r.borrow(); + if r.device.is_none() { + return false; + } + r.device + .as_ref() + .unwrap() + .upgrade() + .unwrap() + .inner + .query("device_name") + .map(|s| s == "cpu") + .unwrap_or(false) + }) +} + +pub fn __env_need_backtrace() -> bool { + match std::env::var("LUISA_BACKTRACE") { + Ok(s) => s == "1" || s == "ON", + Err(_) => false, + } +} + +pub fn __unreachable(file: &str, line: u32, col: u32) { + let path = std::path::Path::new(file); + let pretty_filename: String; + if path.exists() { + pretty_filename = std::fs::canonicalize(path) + .unwrap() + .to_str() + .unwrap() + .to_string(); + } else { + pretty_filename = file.to_string(); + } + let msg = if is_cpu_backend() && __env_need_backtrace() { + let backtrace = get_backtrace(); + format!( + "unreachable code at {}:{}:{} \nbacktrace: {}", + pretty_filename, line, col, backtrace + ) + } else { + format!( + "unreachable code at {}:{}:{} \n", + pretty_filename, line, col + ) + }; + __current_scope(|b| { + b.call( + Func::Unreachable(CBoxedSlice::new( + CString::new(msg).unwrap().into_bytes_with_nul(), + )), + &[], + Type::void(), + ); + }); +} + +pub fn __assert(cond: impl Into>, msg: &str, file: &str, line: u32, col: u32) { + let cond = cond.into(); + let path = std::path::Path::new(file); + let pretty_filename: String; + if path.exists() { + pretty_filename = std::fs::canonicalize(path) + .unwrap() + .to_str() + .unwrap() + .to_string(); + } else { + pretty_filename = file.to_string(); + } + let msg = if is_cpu_backend() && __env_need_backtrace() { + let backtrace = get_backtrace(); + format!( + "assertion failed: {} at {}:{}:{} \nbacktrace: {}", + msg, pretty_filename, line, col, backtrace + ) + } else { + format!( + "assertion failed: {} at {}:{}:{} \n", + msg, pretty_filename, line, col + ) + }; + __current_scope(|b| { + b.call( + Func::Assert(CBoxedSlice::new( + CString::new(msg).unwrap().into_bytes_with_nul(), + )), + &[cond.node()], + Type::void(), + ); + }); +} diff --git a/luisa_compute/src/lang/diff.rs b/luisa_compute/src/lang/diff.rs new file mode 100644 index 0000000..cb37fbb --- /dev/null +++ b/luisa_compute/src/lang/diff.rs @@ -0,0 +1,220 @@ +use std::cell::RefCell; + +use crate::internal_prelude::*; + +struct AdContext { + started: bool, + backward_called: bool, + is_forward_mode: bool, + n_forward_grads: usize, + // forward: Option>, +} + +impl AdContext { + fn new_rev() -> Self { + Self { + started: false, + backward_called: false, + is_forward_mode: false, + n_forward_grads: 0, + } + } + fn new_fwd(n: usize) -> Self { + Self { + started: false, + backward_called: false, + is_forward_mode: true, + n_forward_grads: n, + } + } + fn reset(&mut self) { + self.started = false; + } +} +thread_local! { + static AD_CONTEXT:RefCell = RefCell::new(AdContext::new_rev()); +} +pub fn requires_grad(var: impl ExprProxy) { + AD_CONTEXT.with(|c| { + let c = c.borrow(); + assert!(c.started, "autodiff section is not started"); + assert!( + !c.is_forward_mode, + "requires_grad() is called in forward mode" + ); + assert!(!c.backward_called, "backward is already called"); + }); + __current_scope(|b| { + b.call(Func::RequiresGradient, &[var.node()], Type::void()); + }); +} + +pub fn backward(out: T) { + backward_with_grad( + out, + FromNode::from_node(__current_scope(|b| { + let one = new_node( + b.pools(), + Node::new( + CArc::new(Instruction::Const(Const::One(::type_()))), + ::type_(), + ), + ); + b.append(one); + one + })), + ); +} + +pub fn backward_with_grad(out: T, grad: T) { + AD_CONTEXT.with(|c| { + let mut c = c.borrow_mut(); + assert!(c.started, "autodiff section is not started"); + assert!(!c.is_forward_mode, "backward() is called in forward mode"); + assert!(!c.backward_called, "backward is already called"); + c.backward_called = true; + }); + let out = out.node(); + let grad = grad.node(); + __current_scope(|b| { + b.call(Func::GradientMarker, &[out, grad], Type::void()); + b.call(Func::Backward, &[], Type::void()); + }); +} + +/// Gradient of a value in *Reverse mode* AD +pub fn gradient(var: T) -> T { + AD_CONTEXT.with(|c| { + let c = c.borrow(); + assert!(c.started, "autodiff section is not started"); + assert!(!c.is_forward_mode, "gradient() is called in forward mode"); + assert!(c.backward_called, "backward is not called"); + }); + T::from_node(__current_scope(|b| { + b.call(Func::Gradient, &[var.node()], var.node().type_().clone()) + })) +} +/// Gradient of a value in *Reverse mode* AD +pub fn grad(var: T) -> T { + gradient(var) +} + +// pub fn detach(body: impl FnOnce() -> R) -> R { +// RECORDER.with(|r| { +// let mut r = r.borrow_mut(); +// let s = &mut r.scopes; +// s.push(IrBuilder::new()); +// }); +// let ret = body(); +// let fwd = pop_scope(); +// __current_scope(|b| { +// let node = new_node(Node::new(CArc::new(Instruction::AdDetach(fwd)), Type::void())); +// b.append(node); +// }); +// let nodes = ret.to_vec_nodes(); +// let nodes: Vec<_> = nodes +// .iter() +// .map(|n| __current_scope(|b| b.call(Func::Detach, &[*n], n.type_()))) +// .collect(); +// R::from_vec_nodes(nodes) +// } +pub fn detach(v: T) -> T { + let v = v.node(); + let node = __current_scope(|b| b.call(Func::Detach, &[v], v.type_().clone())); + T::from_node(node) +} + +/// Start a *Forward mode* AD section that propagates N gradients w.r.t to input variable +pub fn forward_autodiff(n_grads: usize, body: impl Fn()) { + AD_CONTEXT.with(|c| { + let mut c = c.borrow_mut(); + assert!(!c.started, "autodiff section already started"); + *c = AdContext::new_fwd(n_grads); + c.started = true; + }); + RECORDER.with(|r| { + let mut r = r.borrow_mut(); + let pools = r.pools.clone().unwrap(); + let s = &mut r.scopes; + s.push(IrBuilder::new(pools)); + }); + body(); + let n_grads = AD_CONTEXT.with(|c| { + let mut c = c.borrow_mut(); + let n_grads = c.n_forward_grads; + c.reset(); + n_grads + }); + let body = __pop_scope(); + __current_scope(|b| { + b.fwd_ad_scope(body, n_grads); + }); +} + +/// Propagate N gradients w.r.t to input variable using *Forward mode* AD +pub fn propagate_gradient(v: T, grads: &[T]) { + AD_CONTEXT.with(|c| { + let c = c.borrow(); + assert_eq!(grads.len(), c.n_forward_grads); + assert!(c.started, "autodiff section is not started"); + assert!( + c.is_forward_mode, + "propagate_gradient() is called in backward mode" + ); + }); + __current_scope(|b| { + let mut nodes = vec![v.node()]; + nodes.extend(grads.iter().map(|g| g.node())); + b.call(Func::PropagateGrad, &nodes, Type::void()); + }); +} + +pub fn output_gradients(v: T) -> Vec { + let n = AD_CONTEXT.with(|c| { + let c = c.borrow(); + assert!(c.started, "autodiff section is not started"); + assert!( + c.is_forward_mode, + "output_gradients() is called in backward mode" + ); + c.n_forward_grads + }); + __current_scope(|b| { + let mut grads = vec![]; + for i in 0..n { + let idx = b.const_(Const::Int32(i as i32)); + grads.push(T::from_node(b.call( + Func::OutputGrad, + &[v.node(), idx], + v.node().type_().clone(), + ))); + } + grads + }) +} + +pub fn autodiff(body: impl Fn()) { + AD_CONTEXT.with(|c| { + let mut c = c.borrow_mut(); + assert!(!c.started, "autodiff section is already started"); + *c = AdContext::new_rev(); + c.started = true; + }); + RECORDER.with(|r| { + let mut r = r.borrow_mut(); + let pools = r.pools.clone().unwrap(); + let s = &mut r.scopes; + s.push(IrBuilder::new(pools)); + }); + body(); + AD_CONTEXT.with(|c| { + let mut c = c.borrow_mut(); + assert!(c.started, "autodiff section is not started"); + assert!(c.backward_called, "backward is not called"); + c.reset(); + }); + let body = __pop_scope(); + __current_scope(|b| { + b.ad_scope(body); + }); +} diff --git a/luisa_compute/src/lang/functions.rs b/luisa_compute/src/lang/functions.rs new file mode 100644 index 0000000..23be6fa --- /dev/null +++ b/luisa_compute/src/lang/functions.rs @@ -0,0 +1,212 @@ +use crate::internal_prelude::*; + +pub fn thread_id() -> Expr { + Expr::::from_node(__current_scope(|b| { + b.call(Func::ThreadId, &[], Uint3::type_()) + })) +} + +pub fn block_id() -> Expr { + Expr::::from_node(__current_scope(|b| { + b.call(Func::BlockId, &[], Uint3::type_()) + })) +} + +pub fn dispatch_id() -> Expr { + Expr::::from_node(__current_scope(|b| { + b.call(Func::DispatchId, &[], Uint3::type_()) + })) +} + +pub fn dispatch_size() -> Expr { + Expr::::from_node(__current_scope(|b| { + b.call(Func::DispatchSize, &[], Uint3::type_()) + })) +} + +fn check_block_size_for_cpu() { + RECORDER.with(|r| { + let r = r.borrow(); + assert!( + r.block_size.is_some(), + "CPU backend only support block operations on block size 1" + ); + let size = r.block_size.unwrap(); + assert_eq!( + size, + [1, 1, 1], + "CPU backend only support block operations on block size 1" + ); + }); +} +pub fn sync_block() { + if is_cpu_backend() { + check_block_size_for_cpu(); + return; + } + __current_scope(|b| { + b.call(Func::SynchronizeBlock, &[], Type::void()); + }) +} + +pub fn warp_is_first_active_lane() -> Expr { + Expr::::from_node(__current_scope(|b| { + b.call(Func::WarpIsFirstActiveLane, &[], Expr::::type_()) + })) +} +pub fn warp_active_all_equal(v: impl ScalarOrVector) -> Expr { + Expr::::from_node(__current_scope(|b| { + b.call( + Func::WarpActiveAllEqual, + &[v.node()], + ::type_(), + ) + })) +} +pub fn warp_active_bit_and, E: IntVarTrait>(v: T) -> T { + T::from_node(__current_scope(|b| { + b.call( + Func::WarpActiveBitAnd, + &[v.node()], + ::type_(), + ) + })) +} + +pub fn warp_active_bit_or, E: IntVarTrait>(v: T) -> T { + T::from_node(__current_scope(|b| { + b.call( + Func::WarpActiveBitOr, + &[v.node()], + ::type_(), + ) + })) +} + +pub fn warp_active_bit_xor, E: IntVarTrait>(v: T) -> T { + T::from_node(__current_scope(|b| { + b.call( + Func::WarpActiveBitXor, + &[v.node()], + ::type_(), + ) + })) +} + +pub fn warp_active_count_bits(v: impl Into>) -> Expr { + Expr::::from_node(__current_scope(|b| { + b.call( + Func::WarpActiveCountBits, + &[v.into().node()], + ::type_(), + ) + })) +} +pub fn warp_active_max(v: T) -> T::Element { + ::from_node(__current_scope(|b| { + b.call(Func::WarpActiveMax, &[v.node()], ::type_()) + })) +} +pub fn warp_active_min(v: T) -> T::Element { + ::from_node(__current_scope(|b| { + b.call(Func::WarpActiveMin, &[v.node()], ::type_()) + })) +} +pub fn warp_active_product(v: T) -> T::Element { + ::from_node(__current_scope(|b| { + b.call( + Func::WarpActiveProduct, + &[v.node()], + ::type_(), + ) + })) +} +pub fn warp_active_sum(v: T) -> T::Element { + ::from_node(__current_scope(|b| { + b.call(Func::WarpActiveSum, &[v.node()], ::type_()) + })) +} +pub fn warp_active_all(v: Expr) -> Expr { + Expr::::from_node(__current_scope(|b| { + b.call(Func::WarpActiveAll, &[v.node()], ::type_()) + })) +} +pub fn warp_active_any(v: Expr) -> Expr { + Expr::::from_node(__current_scope(|b| { + b.call(Func::WarpActiveAny, &[v.node()], ::type_()) + })) +} +pub fn warp_active_bit_mask() -> Expr { + Expr::::from_node(__current_scope(|b| { + b.call(Func::WarpActiveBitMask, &[], ::type_()) + })) +} +pub fn warp_prefix_count_bits(v: Expr) -> Expr { + Expr::::from_node(__current_scope(|b| { + b.call( + Func::WarpPrefixCountBits, + &[v.node()], + ::type_(), + ) + })) +} +pub fn warp_prefix_sum_exclusive(v: T) -> T { + T::from_node(__current_scope(|b| { + b.call(Func::WarpPrefixSum, &[v.node()], v.node().type_().clone()) + })) +} +pub fn warp_prefix_product_exclusive(v: T) -> T { + T::from_node(__current_scope(|b| { + b.call( + Func::WarpPrefixProduct, + &[v.node()], + v.node().type_().clone(), + ) + })) +} +pub fn warp_read_lane_at(v: T, index: impl Into>) -> T { + let index = index.into(); + T::from_node(__current_scope(|b| { + b.call( + Func::WarpReadLaneAt, + &[v.node(), index.node()], + v.node().type_().clone(), + ) + })) +} +pub fn warp_read_first_active_lane(v: T) -> T { + T::from_node(__current_scope(|b| { + b.call( + Func::WarpReadFirstLane, + &[v.node()], + v.node().type_().clone(), + ) + })) +} +pub fn set_block_size(size: [u32; 3]) { + RECORDER.with(|r| { + let mut r = r.borrow_mut(); + assert!( + r.building_kernel, + "set_block_size cannot be called in callable!" + ); + assert!(r.block_size.is_none(), "Block size already set"); + + r.block_size = Some(size); + }); +} + +pub fn block_size() -> Expr { + RECORDER.with(|r| { + let r = r.borrow(); + let s = r.block_size.unwrap_or_else(|| panic!("Block size not set")); + Uint3::new(s[0], s[1], s[2]).expr() + }) +} + +pub unsafe fn bitcast(expr: Expr) -> Expr { + assert_eq!(std::mem::size_of::(), std::mem::size_of::()); + Expr::::from_node(__current_scope(|b| { + b.call(Func::Bitcast, &[expr.node()], ::type_()) + })) +} diff --git a/luisa_compute/src/lang/index.rs b/luisa_compute/src/lang/index.rs new file mode 100644 index 0000000..31d84d3 --- /dev/null +++ b/luisa_compute/src/lang/index.rs @@ -0,0 +1,44 @@ +use crate::internal_prelude::*; + +pub trait IntoIndex { + fn to_u64(&self) -> Expr; +} +impl IntoIndex for i32 { + fn to_u64(&self) -> Expr { + (*self as u64).expr() + } +} +impl IntoIndex for i64 { + fn to_u64(&self) -> Expr { + (*self as u64).expr() + } +} +impl IntoIndex for u32 { + fn to_u64(&self) -> Expr { + (*self as u64).expr() + } +} +impl IntoIndex for u64 { + fn to_u64(&self) -> Expr { + (*self).expr() + } +} +impl IntoIndex for Expr { + fn to_u64(&self) -> Expr { + self.ulong() + } +} +impl IntoIndex for Expr { + fn to_u64(&self) -> Expr { + *self + } +} + +pub trait IndexRead: ToNode { + type Element: Value; + fn read(&self, i: I) -> Expr; +} + +pub trait IndexWrite: IndexRead { + fn write>>(&self, i: I, value: V); +} diff --git a/luisa_compute/src/lang/maybe_expr.rs b/luisa_compute/src/lang/maybe_expr.rs new file mode 100644 index 0000000..2654889 --- /dev/null +++ b/luisa_compute/src/lang/maybe_expr.rs @@ -0,0 +1,256 @@ +//! The purpose of this module is to provide traits to represent things that may +//! either be an expression or a normal value. This is necessary for making the +//! trace macro work for both types of value. + +use super::control_flow::{generic_loop, if_then_else}; +use super::types::core::*; +use crate::internal_prelude::*; + +pub trait BoolIfElseMaybeExpr { + fn if_then_else(self, then: impl FnOnce() -> R, else_: impl FnOnce() -> R) -> R; +} +impl BoolIfElseMaybeExpr for bool { + fn if_then_else(self, then: impl FnOnce() -> R, else_: impl FnOnce() -> R) -> R { + if self { + then() + } else { + else_() + } + } +} +impl BoolIfElseMaybeExpr for Bool { + fn if_then_else(self, then: impl FnOnce() -> R, else_: impl FnOnce() -> R) -> R { + if_then_else(self, then, else_) + } +} + +pub trait BoolIfMaybeExpr { + fn if_then(self, then: impl FnOnce()); +} +impl BoolIfMaybeExpr for bool { + fn if_then(self, then: impl FnOnce()) { + if self { + then() + } + } +} +impl BoolIfMaybeExpr for Bool { + fn if_then(self, then: impl FnOnce()) { + if_then_else(self, then, || {}) + } +} + +pub trait BoolWhileMaybeExpr { + fn while_loop(this: impl FnMut() -> Self, body: impl FnMut()); +} +impl BoolWhileMaybeExpr for bool { + fn while_loop(mut this: impl FnMut() -> Self, mut body: impl FnMut()) { + while this() { + body() + } + } +} +impl BoolWhileMaybeExpr for Bool { + fn while_loop(this: impl FnMut() -> Self, body: impl FnMut()) { + generic_loop(this, body, || {}); + } +} + +// TODO: Support lazy expressions if that isn't done already? +pub trait BoolLazyOpsMaybeExpr { + type Ret; + fn and(self, other: impl FnOnce() -> R) -> Self::Ret; + fn or(self, other: impl FnOnce() -> R) -> Self::Ret; +} +impl BoolLazyOpsMaybeExpr for bool { + type Ret = bool; + fn and(self, other: impl FnOnce() -> bool) -> Self::Ret { + self && other() + } + fn or(self, other: impl FnOnce() -> bool) -> Self::Ret { + self || other() + } +} +impl BoolLazyOpsMaybeExpr for bool { + type Ret = Bool; + fn and(self, other: impl FnOnce() -> Bool) -> Self::Ret { + self & other() + } + fn or(self, other: impl FnOnce() -> Bool) -> Self::Ret { + self | other() + } +} +impl BoolLazyOpsMaybeExpr for Bool { + type Ret = Bool; + fn and(self, other: impl FnOnce() -> bool) -> Self::Ret { + self & other() + } + fn or(self, other: impl FnOnce() -> bool) -> Self::Ret { + self | other() + } +} +impl BoolLazyOpsMaybeExpr for Bool { + type Ret = Bool; + fn and(self, other: impl FnOnce() -> Bool) -> Self::Ret { + self & other() + } + fn or(self, other: impl FnOnce() -> Bool) -> Self::Ret { + self | other() + } +} + +pub trait EqMaybeExpr { + type Bool; + fn eq(self, other: X) -> Self::Bool; + fn ne(self, other: X) -> Self::Bool; +} +impl> EqMaybeExpr for A { + type Bool = bool; + fn eq(self, other: R) -> Self::Bool { + self == other + } + fn ne(self, other: R) -> Self::Bool { + self != other + } +} +macro_rules! impl_eme { + ($t: ty, $s: ty) => { + impl EqMaybeExpr<$s> for $t { + type Bool = <$t as VarTrait>::Bool; + fn eq(self, other: $s) -> Self::Bool { + self.cmpeq(other) + } + fn ne(self, other: $s) -> Self::Bool { + self.cmpne(other) + } + } + }; +} +macro_rules! impl_mem { + ($t: ty, $s: ty) => { + impl EqMaybeExpr<$s> for $t { + type Bool = <$s as VarTrait>::Bool; + fn eq(self, other: $s) -> Self::Bool { + other.cmpeq(self) + } + fn ne(self, other: $s) -> Self::Bool { + other.cmpne(self) + } + } + }; +} +macro_rules! emes { + ($x: ty $(, $y: ty)*) => { + impl_eme!(Expr<$x>, Expr<$x>); + impl_eme!(Expr<$x>, $x); + impl_mem!($x, Expr<$x>); + $(impl_eme!(Expr<$x>, $y); + impl_mem!($y, Expr<$x>);)* + }; +} +emes!(bool); +emes!(Bool2); +emes!(Bool3); +emes!(Bool4); + +pub trait PartialOrdMaybeExpr { + type Bool; + fn lt(self, other: R) -> Self::Bool; + fn le(self, other: R) -> Self::Bool; + fn ge(self, other: R) -> Self::Bool; + fn gt(self, other: R) -> Self::Bool; +} +impl> PartialOrdMaybeExpr for A { + type Bool = bool; + fn lt(self, other: R) -> Self::Bool { + self < other + } + fn le(self, other: R) -> Self::Bool { + self <= other + } + fn ge(self, other: R) -> Self::Bool { + self >= other + } + fn gt(self, other: R) -> Self::Bool { + self > other + } +} +macro_rules! impl_pome { + ($t: ty, $s: ty) => { + impl_eme!($t, $s); + impl PartialOrdMaybeExpr<$s> for $t { + type Bool = <$t as VarTrait>::Bool; + fn lt(self, other: $s) -> Self::Bool { + self.cmplt(other) + } + fn le(self, other: $s) -> Self::Bool { + self.cmple(other) + } + fn ge(self, other: $s) -> Self::Bool { + self.cmpge(other) + } + fn gt(self, other: $s) -> Self::Bool { + self.cmpgt(other) + } + } + }; +} +macro_rules! impl_emop { + ($t: ty, $s: ty) => { + impl_mem!($t, $s); + impl PartialOrdMaybeExpr<$s> for $t { + type Bool = <$s as VarTrait>::Bool; + fn lt(self, other: $s) -> Self::Bool { + other.cmpgt(self) + } + fn le(self, other: $s) -> Self::Bool { + other.cmpge(self) + } + fn ge(self, other: $s) -> Self::Bool { + other.cmplt(self) + } + fn gt(self, other: $s) -> Self::Bool { + other.cmplt(self) + } + } + }; +} +macro_rules! pomes { + ($x: ty $(, $y:ty)*) => { + impl_pome!(Expr<$x>, Expr<$x>); + impl_pome!(Expr<$x>, $x); + impl_emop!($x, Expr<$x>); + impl_pome!(Expr<$x>, Var<$x>); + impl_emop!(Var<$x>, Expr<$x>); + $(impl_pome!(Expr<$x>, $y); + impl_emop!($y, Expr<$x>);)* + }; +} +pomes!(f16); +pomes!(f32); +pomes!(f64); +pomes!(i16); +pomes!(i32); +pomes!(i64); +pomes!(u16); +pomes!(u32); +pomes!(u64); + +pomes!(Float2, Expr, f32); +pomes!(Float3, Expr, f32); +pomes!(Float4, Expr, f32); +pomes!(Double2); +pomes!(Double3); +pomes!(Double4); +pomes!(Int2, Expr); +pomes!(Int3, Expr); +pomes!(Int4, Expr); +pomes!(Uint2, Expr); +pomes!(Uint3, Expr); +pomes!(Uint4, Expr); + +#[allow(dead_code)] +fn tests() { + <_ as BoolWhileMaybeExpr>::while_loop(|| true, || {}); + <_ as BoolWhileMaybeExpr>::while_loop(|| Bool::from(true), || {}); +} diff --git a/luisa_compute/src/lang/mod.rs b/luisa_compute/src/lang/mod.rs deleted file mode 100644 index b8681e7..0000000 --- a/luisa_compute/src/lang/mod.rs +++ /dev/null @@ -1,3106 +0,0 @@ -use std::backtrace::Backtrace; -use std::collections::HashSet; -use std::marker::PhantomData; -use std::{any::Any, collections::HashMap, fmt::Debug, rc::Rc, sync::Arc}; -use std::{env, unreachable}; - -use crate::lang::traits::VarCmp; -pub use crate::runtime::CallableArgEncoder; -use crate::runtime::{AsyncShaderArtifact, ShaderArtifact}; -use crate::*; -use crate::{rtx, ResourceTracker}; -use bumpalo::Bump; -use indexmap::IndexMap; -pub use ir::ir::NodeRef; -use ir::ir::{ - ArrayType, CallableModule, CallableModuleRef, ModuleFlags, ModulePools, SwitchCase, - UserNodeData, INVALID_REF, -}; -pub use ir::CArc; -use ir::Pooled; -use ir::{ - ir::{ - new_node, BasicBlock, Binding, Capture, Const, CpuCustomOp, Func, Instruction, IrBuilder, - KernelModule, Module, ModuleKind, Node, PhiIncoming, - }, - transform::{self, Transform}, -}; - -use luisa_compute_ir as ir; - -pub use luisa_compute_ir::{ - context::register_type, - ffi::CBoxedSlice, - ir::{StructType, Type}, - TypeOf, -}; -use math::Uint3; -use std::cell::{Cell, RefCell, UnsafeCell}; -use std::ffi::CString; -use std::ops::{Bound, Deref, DerefMut, RangeBounds}; -use std::sync::atomic::AtomicUsize; -// use self::math::Uint3; -pub mod math; -pub mod poly; -pub mod printer; -pub mod swizzle; -pub mod traits; - -pub use math::*; -pub use poly::*; -pub use printer::*; - -pub(crate) static KERNEL_ID: AtomicUsize = AtomicUsize::new(0); -// prevent node being shared across kernels -// TODO: replace NodeRef with SafeNodeRef -#[derive(Clone, Copy, Debug)] -pub(crate) struct SafeNodeRef { - pub(crate) node: NodeRef, - pub(crate) kernel_id: usize, -} -pub trait Value: Copy + ir::TypeOf + 'static { - type Expr: ExprProxy; - type Var: VarProxy; - fn fields() -> Vec; -} - -pub trait StructInitiaizable: Value { - type Init: Into; -} - -pub trait Aggregate: Sized { - fn to_vec_nodes(&self) -> Vec { - let mut nodes = vec![]; - Self::to_nodes(&self, &mut nodes); - nodes - } - fn from_vec_nodes(nodes: Vec) -> Self { - let mut iter = nodes.into_iter(); - let ret = Self::from_nodes(&mut iter); - assert!(iter.next().is_none()); - ret - } - fn to_nodes(&self, nodes: &mut Vec); - fn from_nodes>(iter: &mut I) -> Self; -} - -pub trait ToNode { - fn node(&self) -> NodeRef; -} - -pub trait FromNode: ToNode { - fn from_node(node: NodeRef) -> Self; -} - -fn _store(var: &T1, value: &T2) { - let value_nodes = value.to_vec_nodes(); - let self_nodes = var.to_vec_nodes(); - assert_eq!(value_nodes.len(), self_nodes.len()); - __current_scope(|b| { - for (value_node, self_node) in value_nodes.into_iter().zip(self_nodes.into_iter()) { - b.update(self_node, value_node); - } - }) -} - -#[inline(always)] -pub fn __new_user_node(data: T) -> NodeRef { - use luisa_compute_ir::ir::new_user_node; - new_user_node(__module_pools(), data) -} -macro_rules! impl_aggregate_for_tuple { - ()=>{ - impl Aggregate for () { - fn to_nodes(&self, _: &mut Vec) {} - fn from_nodes>(_: &mut I) -> Self{} - } - }; - ($first:ident $($rest:ident) *) => { - impl<$first:Aggregate, $($rest: Aggregate),*> Aggregate for ($first, $($rest,)*) { - #[allow(non_snake_case)] - fn to_nodes(&self, nodes: &mut Vec) { - let ($first, $($rest,)*) = self; - $first.to_nodes(nodes); - $($rest.to_nodes(nodes);)* - } - #[allow(non_snake_case)] - fn from_nodes>(iter: &mut I) -> Self { - let $first = Aggregate::from_nodes(iter); - $(let $rest = Aggregate::from_nodes(iter);)* - ($first, $($rest,)*) - } - } - impl_aggregate_for_tuple!($($rest)*); - }; - -} -impl_aggregate_for_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); - -pub unsafe trait Mask: ToNode {} -pub trait IntoIndex { - fn to_u64(&self) -> Expr; -} -impl IntoIndex for i32 { - fn to_u64(&self) -> Expr { - const_(*self as u64) - } -} -impl IntoIndex for i64 { - fn to_u64(&self) -> Expr { - const_(*self as u64) - } -} -impl IntoIndex for u32 { - fn to_u64(&self) -> Expr { - const_(*self as u64) - } -} -impl IntoIndex for u64 { - fn to_u64(&self) -> Expr { - const_(*self) - } -} -impl IntoIndex for PrimExpr { - fn to_u64(&self) -> Expr { - self.ulong() - } -} -impl IntoIndex for PrimExpr { - fn to_u64(&self) -> Expr { - *self - } -} - -pub trait IndexRead: ToNode { - type Element: Value; - fn read(&self, i: I) -> Expr; -} - -pub trait IndexWrite: IndexRead { - fn write>>(&self, i: I, value: V); -} - -pub fn select(mask: impl Mask, a: A, b: A) -> A { - let a_nodes = a.to_vec_nodes(); - let b_nodes = b.to_vec_nodes(); - assert_eq!(a_nodes.len(), b_nodes.len()); - let mut ret = vec![]; - __current_scope(|b| { - for (a_node, b_node) in a_nodes.into_iter().zip(b_nodes.into_iter()) { - assert_eq!(a_node.type_(), b_node.type_()); - assert!(!a_node.is_local(), "cannot select local variables"); - assert!(!b_node.is_local(), "cannot select local variables"); - if a_node.is_user_data() || b_node.is_user_data() { - assert!( - a_node.is_user_data() && b_node.is_user_data(), - "cannot select user data and non-user data" - ); - let a_data = a_node.get_user_data(); - let b_data = b_node.get_user_data(); - if a_data != b_data { - panic!("cannot select different user data"); - } - ret.push(a_node); - } else { - ret.push(b.call( - Func::Select, - &[mask.node(), a_node, b_node], - a_node.type_().clone(), - )); - } - } - }); - A::from_vec_nodes(ret) -} - -impl ToNode for bool { - fn node(&self) -> NodeRef { - const_(*self).node() - } -} - -unsafe impl Mask for bool {} - -unsafe impl Mask for Bool {} - -pub trait ExprProxy: Copy + Aggregate + FromNode { - type Value: Value; -} - -pub struct VarDerefProxy -where - P: VarProxy, -{ - pub(crate) var: P, - pub(crate) dirty: Cell, - pub(crate) assigned: Expr, - pub(crate) _phantom: std::marker::PhantomData, -} - -impl Deref for VarDerefProxy -where - P: VarProxy, -{ - type Target = Expr; - fn deref(&self) -> &Self::Target { - &self.assigned - } -} - -impl DerefMut for VarDerefProxy -where - P: VarProxy, -{ - fn deref_mut(&mut self) -> &mut Self::Target { - self.dirty.set(true); - &mut self.assigned - } -} - -impl Drop for VarDerefProxy -where - P: VarProxy, -{ - fn drop(&mut self) { - if self.dirty.get() { - self.var.store(self.assigned) - } - } -} -macro_rules! impl_assign_ops { - ($ass:ident, $ass_m:ident, $o:ident, $o_m:ident) => { - impl std::ops::$ass for VarDerefProxy - where - P: VarProxy, - ::Expr: std::ops::$o::Expr>, - { - fn $ass_m(&mut self, rhs: Rhs) { - *self.deref_mut() = std::ops::$o::$o_m(**self, rhs); - } - } - }; -} -impl_assign_ops!(AddAssign, add_assign, Add, add); -impl_assign_ops!(SubAssign, sub_assign, Sub, sub); -impl_assign_ops!(MulAssign, mul_assign, Mul, mul); -impl_assign_ops!(DivAssign, div_assign, Div, div); -impl_assign_ops!(RemAssign, rem_assign, Rem, rem); -impl_assign_ops!(BitAndAssign, bitand_assign, BitAnd, bitand); -impl_assign_ops!(BitOrAssign, bitor_assign, BitOr, bitor); -impl_assign_ops!(BitXorAssign, bitxor_assign, BitXor, bitxor); -impl_assign_ops!(ShlAssign, shl_assign, Shl, shl); -impl_assign_ops!(ShrAssign, shr_assign, Shr, shr); - -pub trait VarProxy: Copy + Aggregate + FromNode { - type Value: Value; - fn store>>(&self, value: U) { - let value = value.into(); - _store(self, &value); - } - fn load(&self) -> Expr { - __current_scope(|b| { - let nodes = self.to_vec_nodes(); - let mut ret = vec![]; - for node in nodes { - ret.push(b.call(Func::Load, &[node], node.type_().clone())); - } - Expr::::from_nodes(&mut ret.into_iter()) - }) - } - fn get_mut(&self) -> VarDerefProxy { - VarDerefProxy { - var: *self, - dirty: Cell::new(false), - assigned: self.load(), - _phantom: std::marker::PhantomData, - } - } - fn _deref<'a>(&'a self) -> &'a Expr { - RECORDER.with(|r| { - let v: Expr = self.load(); - let r = r.borrow(); - let v: &Expr = r.arena.alloc(v); - unsafe { - let v: &'a Expr = std::mem::transmute(v); - v - } - }) - } -} - -#[derive(Clone, Copy, Debug)] -pub struct PrimExpr { - pub(crate) node: NodeRef, - pub(crate) _phantom: std::marker::PhantomData, -} - -#[derive(Clone, Copy, Debug)] -pub struct PrimVar { - pub(crate) node: NodeRef, - pub(crate) _phantom: std::marker::PhantomData, -} - -impl Aggregate for PrimExpr { - fn to_nodes(&self, nodes: &mut Vec) { - nodes.push(self.node); - } - fn from_nodes>(iter: &mut I) -> Self { - Self { - node: iter.next().unwrap(), - _phantom: std::marker::PhantomData, - } - } -} - -impl Aggregate for PrimVar { - fn to_nodes(&self, nodes: &mut Vec) { - nodes.push(self.node); - } - fn from_nodes>(iter: &mut I) -> Self { - Self { - node: iter.next().unwrap(), - _phantom: std::marker::PhantomData, - } - } -} -#[macro_export] -macro_rules! impl_callable_param { - ($t:ty, $e:ty, $v:ty) => { - impl CallableParameter for $e { - fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { - builder.value::<$t>() - } - fn encode(&self, encoder: &mut CallableArgEncoder) { - encoder.var(*self) - } - } - impl CallableParameter for $v { - fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { - builder.var::<$t>() - } - fn encode(&self, encoder: &mut CallableArgEncoder) { - encoder.var(*self) - } - } - }; -} - -macro_rules! impl_prim { - ($t:ty) => { - impl From<$t> for PrimExpr<$t> { - fn from(v: $t) -> Self { - const_(v) - } - } - impl FromNode for PrimVar<$t> { - fn from_node(node: NodeRef) -> Self { - Self { - node, - _phantom: std::marker::PhantomData, - } - } - } - impl ToNode for PrimVar<$t> { - fn node(&self) -> NodeRef { - self.node - } - } - impl ExprProxy for PrimExpr<$t> { - type Value = $t; - } - impl VarProxy for PrimVar<$t> { - type Value = $t; - } - impl Deref for PrimVar<$t> { - type Target = PrimExpr<$t>; - fn deref(&self) -> &Self::Target { - self._deref() - } - } - impl Value for $t { - type Expr = PrimExpr<$t>; - type Var = PrimVar<$t>; - fn fields() -> Vec { - vec![] - } - } - impl_callable_param!($t, PrimExpr<$t>, PrimVar<$t>); - }; -} - -impl_prim!(bool); -impl_prim!(u32); -impl_prim!(u64); -impl_prim!(i32); -impl_prim!(i64); -impl_prim!(i16); -impl_prim!(u16); -impl_prim!(f16); -impl_prim!(f32); -impl_prim!(f64); - -pub type Bool = PrimExpr; -pub type F16 = PrimExpr; -pub type F32 = PrimExpr; -pub type F64 = PrimExpr; -pub type I16 = PrimExpr; -pub type I32 = PrimExpr; -pub type I64 = PrimExpr; -pub type U16 = PrimExpr; -pub type U32 = PrimExpr; -pub type U64 = PrimExpr; - -pub type F16Var = PrimVar; -pub type F32Var = PrimVar; -pub type F64Var = PrimVar; -pub type I16Var = PrimVar; -pub type I32Var = PrimVar; -pub type I64Var = PrimVar; -pub type U16Var = PrimVar; -pub type U32Var = PrimVar; -pub type U64Var = PrimVar; - -pub type Half = PrimExpr; -pub type Float = PrimExpr; -pub type Double = PrimExpr; -pub type Int = PrimExpr; -pub type Long = PrimExpr; -pub type Uint = PrimExpr; -pub type Ulong = PrimExpr; -pub type Short = PrimExpr; -pub type Ushort = PrimExpr; - -pub type BoolVar = PrimVar; -pub type HalfVar = PrimVar; -pub type FloatVar = PrimVar; -pub type DoubleVar = PrimVar; -pub type IntVar = PrimVar; -pub type LongVar = PrimVar; -pub type UintVar = PrimVar; -pub type UlongVar = PrimVar; -pub type ShortVar = PrimVar; -pub type UshortVar = PrimVar; - -pub struct CpuFn { - op: CArc, - _marker: std::marker::PhantomData, -} -#[macro_export] -macro_rules! cpu_dbg { - ($arg:expr) => {{ - $crate::lang::__cpu_dbg($arg, file!(), line!()) - }}; -} -#[macro_export] -macro_rules! lc_dbg { - ($arg:expr) => {{ - $crate::lang::__cpu_dbg($arg, file!(), line!()) - }}; -} -#[macro_export] -macro_rules! lc_unreachable { - () => { - $crate::lang::__unreachable(file!(), line!(), column!()) - }; -} -#[macro_export] -macro_rules! lc_assert { - ($arg:expr) => { - __assert($arg, stringify!($arg), file!(), line!(), column!()) - }; - ($arg:expr, $msg:expr) => { - __assert($arg, $msg, file!(), line!(), column!()) - }; -} -pub fn __cpu_dbg(arg: T, file: &'static str, line: u32) -where - T::Value: Debug, -{ - if !is_cpu_backend() { - return; - } - let f = CpuFn::new(move |x: &mut T::Value| { - println!("[{}:{}] {:?}", file, line, x); - }); - let _ = f.call(arg); -} - -extern "C" fn _trampoline(data: *mut u8, args: *mut u8) { - unsafe { - let container = &*(data as *const ClosureContainer); - let f = &container.f; - let args = &mut *(args as *mut T); - f(args); - } -} - -extern "C" fn _drop(data: *mut u8) { - unsafe { - let _ = Box::from_raw(data as *mut T); - } -} -/* -Interestingly, Box::into_raw(Box) does not give a valid pointer. -*/ -struct ClosureContainer { - f: Arc, -} - -impl CpuFn { - pub fn new(f: F) -> Self { - let f_ptr = Box::into_raw(Box::new(ClosureContainer:: { f: Arc::new(f) })); - let op = CpuCustomOp { - data: f_ptr as *mut u8, - func: _trampoline::, - destructor: _drop::, - arg_type: T::type_(), - }; - Self { - op: CArc::new(op), - _marker: std::marker::PhantomData, - } - } - pub fn call(&self, arg: impl ExprProxy) -> Expr { - RECORDER.with(|r| { - let mut r = r.borrow_mut(); - assert!(r.lock); - assert_eq!( - r.device - .as_ref() - .unwrap() - .upgrade() - .unwrap() - .inner - .query("device_name") - .unwrap(), - "cpu", - "CpuFn can only be used in cpu backend" - ); - let addr = CArc::as_ptr(&self.op) as u64; - if let Some((_, op)) = r.cpu_custom_ops.get(&addr) { - assert_eq!(CArc::as_ptr(op), CArc::as_ptr(&self.op)); - } else { - let i = r.cpu_custom_ops.len(); - r.cpu_custom_ops.insert(addr, (i, self.op.clone())); - } - }); - Expr::::from_node(__current_scope(|b| { - b.call( - Func::CpuCustomOp(self.op.clone()), - &[arg.node()], - T::type_(), - ) - })) - } -} - -pub(crate) struct Recorder { - pub(crate) scopes: Vec, - pub(crate) kernel_id: Option, - pub(crate) lock: bool, - pub(crate) captured_buffer: IndexMap)>, - pub(crate) cpu_custom_ops: IndexMap)>, - pub(crate) callables: IndexMap, - pub(crate) shared: Vec, - pub(crate) device: Option, - pub(crate) block_size: Option<[u32; 3]>, - pub(crate) building_kernel: bool, - pub(crate) pools: Option>, - pub(crate) arena: Bump, - pub(crate) callable_ret_type: Option>, -} - -impl Recorder { - fn reset(&mut self) { - self.scopes.clear(); - self.captured_buffer.clear(); - self.cpu_custom_ops.clear(); - self.callables.clear(); - self.lock = false; - self.device = None; - self.block_size = None; - self.arena.reset(); - self.shared.clear(); - self.kernel_id = None; - self.callable_ret_type = None; - } - pub(crate) fn new() -> Self { - Recorder { - scopes: vec![], - lock: false, - captured_buffer: IndexMap::new(), - cpu_custom_ops: IndexMap::new(), - callables: IndexMap::new(), - shared: vec![], - device: None, - block_size: None, - pools: None, - arena: Bump::new(), - building_kernel: false, - kernel_id: None, - callable_ret_type: None, - } - } -} -thread_local! { - pub(crate) static RECORDER: RefCell = RefCell::new(Recorder::new()); -} - -// Don't call this function directly unless you know what you are doing -pub fn __current_scope R, R>(f: F) -> R { - RECORDER.with(|r| { - let mut r = r.borrow_mut(); - assert!(r.lock, "__current_scope must be called within a kernel"); - let s = &mut r.scopes; - f(s.last_mut().unwrap()) - }) -} - -pub(crate) fn __invoke_callable(callable: &CallableModuleRef, args: &[NodeRef]) -> NodeRef { - RECORDER.with(|r| { - let mut r = r.borrow_mut(); - let id = CArc::as_ptr(&callable.0) as u64; - if let Some(c) = r.callables.get(&id) { - assert_eq!(CArc::as_ptr(&c.0), CArc::as_ptr(&callable.0)); - } else { - r.callables.insert(id, callable.clone()); - } - }); - __current_scope(|b| { - b.call( - Func::Callable(callable.clone()), - args, - callable.0.ret_type.clone(), - ) - }) -} - -pub(crate) fn __check_node_type(a: NodeRef, b: NodeRef) -> bool { - if !ir::context::is_type_equal(a.type_(), b.type_()) { - return false; - } - match (a.get().instruction.as_ref(), b.get().instruction.as_ref()) { - (Instruction::Buffer, Instruction::Buffer) => true, - (Instruction::Texture2D, Instruction::Texture2D) => true, - (Instruction::Texture3D, Instruction::Texture3D) => true, - (Instruction::Bindless, Instruction::Bindless) => true, - (Instruction::Accel, Instruction::Accel) => true, - (Instruction::Uniform, Instruction::Uniform) => true, - (Instruction::Local { .. }, Instruction::Local { .. }) => true, - (Instruction::Argument { by_value: true }, _) => b.get().instruction.has_value(), - (Instruction::Argument { by_value: false }, _) => b.is_lvalue(), - _ => false, - } -} - -pub(crate) fn __check_callable(callable: &CallableModuleRef, args: &[NodeRef]) -> bool { - assert_eq!(callable.0.args.len(), args.len()); - for i in 0..args.len() { - if !__check_node_type(callable.0.args[i], args[i]) { - return false; - } - } - true -} - -// Don't call this function directly unless you know what you are doing -pub fn __pop_scope() -> Pooled { - RECORDER.with(|r| { - let mut r = r.borrow_mut(); - let s = &mut r.scopes; - s.pop().unwrap().finish() - }) -} - -pub fn __module_pools() -> &'static CArc { - RECORDER.with(|r| { - let r = r.borrow(); - assert!(r.lock, "__module_pools must be called within a kernel"); - let pool = r.pools.as_ref().unwrap(); - unsafe { std::mem::transmute(pool) } - }) -} -// pub fn __load(node: NodeRef) -> Expr { -// __current_scope(|b| { -// let node = b.load(node); -// Expr::::from_node(node) -// }) -// } -// pub fn __store(var:NodeRef, value:NodeRef) { -// let inst = &var.get().instruction; -// } - -pub fn __extract(node: NodeRef, index: usize) -> NodeRef { - let inst = &node.get().instruction; - __current_scope(|b| { - let i = b.const_(Const::Int32(index as i32)); - let op = match inst.as_ref() { - Instruction::Local { .. } => Func::GetElementPtr, - Instruction::Argument { by_value } => { - if *by_value { - Func::ExtractElement - } else { - Func::GetElementPtr - } - } - Instruction::Call(f, args) => match f { - Func::AtomicRef => { - let mut indices = args.to_vec(); - indices.push(i); - return b.call(Func::AtomicRef, &indices, ::type_()); - } - _ => Func::ExtractElement, - }, - _ => Func::ExtractElement, - }; - let node = b.call(op, &[node, i], ::type_()); - node - }) -} - -pub fn __insert(node: NodeRef, index: usize, value: NodeRef) -> NodeRef { - let inst = &node.get().instruction; - __current_scope(|b| { - let i = b.const_(Const::Int32(index as i32)); - let op = match inst.as_ref() { - Instruction::Local { .. } => panic!("Can't insert into local variable"), - _ => Func::InsertElement, - }; - let node = b.call(op, &[node, value, i], ::type_()); - node - }) -} - -pub fn __compose(nodes: &[NodeRef]) -> NodeRef { - let ty = ::type_(); - match ty.as_ref() { - Type::Struct(st) => { - assert_eq!(st.fields.as_ref().len(), nodes.len()); - __current_scope(|b| b.call(Func::Struct, nodes, ::type_())) - } - Type::Primitive(_) => panic!("Can't compose primitive type"), - Type::Vector(vt) => { - let length = vt.length; - let func = match length { - 2 => Func::Vec2, - 3 => Func::Vec3, - 4 => Func::Vec4, - _ => panic!("Can't compose vector with length {}", length), - }; - __current_scope(|b| b.call(func, nodes, ::type_())) - } - Type::Matrix(vt) => { - let length = vt.dimension; - let func = match length { - 2 => Func::Mat2, - 3 => Func::Mat3, - 4 => Func::Mat4, - _ => panic!("Can't compose vector with length {}", length), - }; - __current_scope(|b| b.call(func, nodes, ::type_())) - } - _ => todo!(), - } -} -#[macro_export] -macro_rules! struct_ { - ($t:ty { $($it:ident : $value:expr), * $(,)?}) =>{ - { - type Init = <$t as $crate::lang::StructInitiaizable>::Init; - let init = Init { $($it : $value), * }; - type Expr = <$t as $crate::lang::Value>::Expr; - let e:Expr = init.into(); - e - } - } -} -#[macro_export] -macro_rules! var { - ($t:ty) => { - $crate::lang::local_zeroed::<$t>() - }; - ($t:ty, 0) => { - $crate::lang::local_zeroed::<$t>() - }; - ($t:ty, $init:expr) => { - $crate::lang::local::<$t>($init.into()) - }; - ($e:expr) => { - $crate::lang::def($e) - }; -} -pub fn def, T: Value>(init: E) -> Var { - Var::::from_node(__current_scope(|b| b.local(init.node()))) -} -pub fn local(init: Expr) -> Var { - Var::::from_node(__current_scope(|b| b.local(init.node()))) -} - -pub fn local_zeroed() -> Var { - Var::::from_node(__current_scope(|b| { - b.local_zero_init(::type_()) - })) -} - -pub fn thread_id() -> Expr { - Expr::::from_node(__current_scope(|b| { - b.call(Func::ThreadId, &[], Uint3::type_()) - })) -} - -pub fn block_id() -> Expr { - Expr::::from_node(__current_scope(|b| { - b.call(Func::BlockId, &[], Uint3::type_()) - })) -} - -pub fn dispatch_id() -> Expr { - Expr::::from_node(__current_scope(|b| { - b.call(Func::DispatchId, &[], Uint3::type_()) - })) -} - -pub fn dispatch_size() -> Expr { - Expr::::from_node(__current_scope(|b| { - b.call(Func::DispatchSize, &[], Uint3::type_()) - })) -} -fn check_block_size_for_cpu() { - RECORDER.with(|r| { - let r = r.borrow(); - assert!( - r.block_size.is_some(), - "CPU backend only support block operations on block size 1" - ); - let size = r.block_size.unwrap(); - assert_eq!( - size, - [1, 1, 1], - "CPU backend only support block operations on block size 1" - ); - }); -} -pub fn sync_block() { - if is_cpu_backend() { - check_block_size_for_cpu(); - return; - } - __current_scope(|b| { - b.call(Func::SynchronizeBlock, &[], Type::void()); - }) -} - -pub fn warp_is_first_active_lane() -> Expr { - Expr::::from_node(__current_scope(|b| { - b.call(Func::WarpIsFirstActiveLane, &[], Bool::type_()) - })) -} -pub fn warp_active_all_equal(v: impl ScalarOrVector) -> Expr { - Expr::::from_node(__current_scope(|b| { - b.call( - Func::WarpActiveAllEqual, - &[v.node()], - ::type_(), - ) - })) -} -pub fn warp_active_bit_and, E: IntVarTrait>(v: T) -> T { - T::from_node(__current_scope(|b| { - b.call( - Func::WarpActiveBitAnd, - &[v.node()], - ::type_(), - ) - })) -} - -pub fn warp_active_bit_or, E: IntVarTrait>(v: T) -> T { - T::from_node(__current_scope(|b| { - b.call( - Func::WarpActiveBitOr, - &[v.node()], - ::type_(), - ) - })) -} - -pub fn warp_active_bit_xor, E: IntVarTrait>(v: T) -> T { - T::from_node(__current_scope(|b| { - b.call( - Func::WarpActiveBitXor, - &[v.node()], - ::type_(), - ) - })) -} - -pub fn warp_active_count_bits(v: impl Into>) -> Expr { - Expr::::from_node(__current_scope(|b| { - b.call( - Func::WarpActiveCountBits, - &[v.into().node()], - ::type_(), - ) - })) -} -pub fn warp_active_max(v: T) -> T::Element { - ::from_node(__current_scope(|b| { - b.call(Func::WarpActiveMax, &[v.node()], ::type_()) - })) -} -pub fn warp_active_min(v: T) -> T::Element { - ::from_node(__current_scope(|b| { - b.call(Func::WarpActiveMin, &[v.node()], ::type_()) - })) -} -pub fn warp_active_product(v: T) -> T::Element { - ::from_node(__current_scope(|b| { - b.call( - Func::WarpActiveProduct, - &[v.node()], - ::type_(), - ) - })) -} -pub fn warp_active_sum(v: T) -> T::Element { - ::from_node(__current_scope(|b| { - b.call(Func::WarpActiveSum, &[v.node()], ::type_()) - })) -} -pub fn warp_active_all(v: Expr) -> Expr { - Expr::::from_node(__current_scope(|b| { - b.call(Func::WarpActiveAll, &[v.node()], ::type_()) - })) -} -pub fn warp_active_any(v: Expr) -> Expr { - Expr::::from_node(__current_scope(|b| { - b.call(Func::WarpActiveAny, &[v.node()], ::type_()) - })) -} -pub fn warp_active_bit_mask() -> Expr { - Expr::::from_node(__current_scope(|b| { - b.call(Func::WarpActiveBitMask, &[], ::type_()) - })) -} -pub fn warp_prefix_count_bits(v: Expr) -> Expr { - Expr::::from_node(__current_scope(|b| { - b.call( - Func::WarpPrefixCountBits, - &[v.node()], - ::type_(), - ) - })) -} -pub fn warp_prefix_sum_exclusive(v: T) -> T { - T::from_node(__current_scope(|b| { - b.call(Func::WarpPrefixSum, &[v.node()], v.node().type_().clone()) - })) -} -pub fn warp_prefix_product_exclusive(v: T) -> T { - T::from_node(__current_scope(|b| { - b.call( - Func::WarpPrefixProduct, - &[v.node()], - v.node().type_().clone(), - ) - })) -} -pub fn warp_read_lane_at(v: T, index: impl Into>) -> T { - let index = index.into(); - T::from_node(__current_scope(|b| { - b.call( - Func::WarpReadLaneAt, - &[v.node(), index.node()], - v.node().type_().clone(), - ) - })) -} -pub fn warp_read_first_active_lane(v: T) -> T { - T::from_node(__current_scope(|b| { - b.call( - Func::WarpReadFirstLane, - &[v.node()], - v.node().type_().clone(), - ) - })) -} -pub fn set_block_size(size: [u32; 3]) { - RECORDER.with(|r| { - let mut r = r.borrow_mut(); - assert!( - r.building_kernel, - "set_block_size cannot be called in callable!" - ); - assert!(r.block_size.is_none(), "Block size already set"); - - r.block_size = Some(size); - }); -} - -pub fn block_size() -> Expr { - RECORDER.with(|r| { - let r = r.borrow(); - let s = r.block_size.unwrap_or_else(|| panic!("Block size not set")); - const_::(Uint3::new(s[0], s[1], s[2])) - }) -} - -pub type Expr = ::Expr; -pub type Var = ::Var; - -pub fn zeroed() -> T::Expr { - FromNode::from_node(__current_scope(|b| b.zero_initializer(T::type_()))) -} - -pub fn const_(value: T) -> T::Expr { - let node = __current_scope(|s| -> NodeRef { - let any = &value as &dyn Any; - if let Some(value) = any.downcast_ref::() { - s.const_(Const::Bool(*value)) - } else if let Some(value) = any.downcast_ref::() { - s.const_(Const::Int32(*value)) - } else if let Some(value) = any.downcast_ref::() { - s.const_(Const::Uint32(*value)) - } else if let Some(value) = any.downcast_ref::() { - s.const_(Const::Int64(*value)) - } else if let Some(value) = any.downcast_ref::() { - s.const_(Const::Uint64(*value)) - } else if let Some(value) = any.downcast_ref::() { - s.const_(Const::Float32(*value)) - } else if let Some(value) = any.downcast_ref::() { - s.const_(Const::Float64(*value)) - } else { - let mut buf = vec![0u8; std::mem::size_of::()]; - unsafe { - std::ptr::copy_nonoverlapping( - &value as *const T as *const u8, - buf.as_mut_ptr(), - buf.len(), - ); - } - s.const_(Const::Generic(CBoxedSlice::new(buf), T::type_())) - } - }); - FromNode::from_node(node) -} - -pub fn bitcast(expr: Expr) -> Expr { - assert_eq!(std::mem::size_of::(), std::mem::size_of::()); - Expr::::from_node(__current_scope(|b| { - b.call(Func::Bitcast, &[expr.node()], ::type_()) - })) -} - -pub const fn packed_size() -> usize { - (std::mem::size_of::() + 3) / 4 -} - -pub fn pack_to(expr: E, buffer: &B, index: impl Into>) -where - E: ExprProxy, - B: IndexWrite, -{ - let index = index.into(); - __current_scope(|b| { - b.call( - Func::Pack, - &[expr.node(), buffer.node(), index.node()], - Type::void(), - ); - }); -} - -pub fn unpack_from( - buffer: &impl IndexWrite, - index: impl Into>, -) -> Expr -where - T: Value, -{ - let index = index.into(); - Expr::::from_node(__current_scope(|b| { - b.call( - Func::Unpack, - &[buffer.node(), index.node()], - ::type_(), - ) - })) -} - -impl Value for [T; N] { - type Expr = ArrayExpr; - type Var = ArrayVar; - fn fields() -> Vec { - todo!("why this method exists?") - } -} - -#[derive(Clone, Copy)] -pub struct DynExpr { - node: NodeRef, -} - -impl From for DynExpr { - fn from(value: T) -> Self { - Self { node: value.node() } - } -} - -impl From for DynVar { - fn from(value: T) -> Self { - Self { node: value.node() } - } -} - -impl DynExpr { - pub fn downcast(&self) -> Option> { - if ir::context::is_type_equal(self.node.type_(), &T::type_()) { - Some(Expr::::from_node(self.node)) - } else { - None - } - } - pub fn get(&self) -> Expr { - self.downcast::().unwrap_or_else(|| { - panic!( - "DynExpr::get: type mismatch: expected {}, got {}", - std::any::type_name::(), - self.node.type_().to_string() - ) - }) - } - pub fn downcast_array(&self, len: usize) -> Option> { - let array_type = ir::context::register_type(Type::Array(ArrayType { - element: T::type_(), - length: len, - })); - if ir::context::is_type_equal(self.node.type_(), &array_type) { - Some(VLArrayExpr::::from_node(self.node)) - } else { - None - } - } - pub fn get_array(&self, len: usize) -> VLArrayExpr { - let array_type = ir::context::register_type(Type::Array(ArrayType { - element: T::type_(), - length: len, - })); - self.downcast_array::(len).unwrap_or_else(|| { - panic!( - "DynExpr::get: type mismatch: expected {}, got {}", - array_type, - self.node.type_().to_string() - ) - }) - } - pub fn new(expr: E) -> Self { - Self { node: expr.node() } - } -} - -impl CallableParameter for DynExpr { - fn def_param(arg: Option>, builder: &mut KernelBuilder) -> Self { - let arg = arg.unwrap_or_else(|| panic!("DynExpr should be used in DynCallable only!")); - let arg = arg.downcast_ref::().unwrap(); - let node = builder.arg(arg.node.type_().clone(), true); - Self { node } - } - fn encode(&self, encoder: &mut CallableArgEncoder) { - encoder.args.push(self.node) - } -} - -impl Aggregate for DynExpr { - fn to_nodes(&self, nodes: &mut Vec) { - nodes.push(self.node) - } - fn from_nodes>(iter: &mut I) -> Self { - Self { - node: iter.next().unwrap(), - } - } -} - -impl FromNode for DynExpr { - fn from_node(node: NodeRef) -> Self { - Self { node } - } -} - -impl ToNode for DynExpr { - fn node(&self) -> NodeRef { - self.node - } -} - -unsafe impl CallableRet for DynExpr { - fn _return(&self) -> CArc { - __current_scope(|b| { - b.return_(self.node); - }); - self.node.type_().clone() - } - fn _from_return(node: NodeRef) -> Self { - Self::from_node(node) - } -} - -impl Aggregate for DynVar { - fn to_nodes(&self, nodes: &mut Vec) { - nodes.push(self.node) - } - fn from_nodes>(iter: &mut I) -> Self { - Self { - node: iter.next().unwrap(), - } - } -} - -impl FromNode for DynVar { - fn from_node(node: NodeRef) -> Self { - Self { node } - } -} - -impl ToNode for DynVar { - fn node(&self) -> NodeRef { - self.node - } -} - -#[derive(Clone, Copy)] -pub struct DynVar { - node: NodeRef, -} - -impl CallableParameter for DynVar { - fn def_param(arg: Option>, builder: &mut KernelBuilder) -> Self { - let arg = arg.unwrap_or_else(|| panic!("DynVar should be used in DynCallable only!")); - let arg = arg.downcast_ref::().unwrap(); - let node = builder.arg(arg.node.type_().clone(), false); - Self { node } - } - fn encode(&self, encoder: &mut CallableArgEncoder) { - encoder.args.push(self.node) - } -} - -impl DynVar { - pub fn downcast(&self) -> Option> { - if ir::context::is_type_equal(self.node.type_(), &T::type_()) { - Some(Var::::from_node(self.node)) - } else { - None - } - } - pub fn get(&self) -> Var { - self.downcast::().unwrap_or_else(|| { - panic!( - "DynVar::get: type mismatch: expected {}, got {}", - std::any::type_name::(), - self.node.type_().to_string() - ) - }) - } - pub fn downcast_array(&self, len: usize) -> Option> { - let array_type = ir::context::register_type(Type::Array(ArrayType { - element: T::type_(), - length: len, - })); - if ir::context::is_type_equal(self.node.type_(), &array_type) { - Some(VLArrayVar::::from_node(self.node)) - } else { - None - } - } - pub fn get_array(&self, len: usize) -> VLArrayVar { - let array_type = ir::context::register_type(Type::Array(ArrayType { - element: T::type_(), - length: len, - })); - self.downcast_array::(len).unwrap_or_else(|| { - panic!( - "DynExpr::get: type mismatch: expected {}, got {}", - array_type, - self.node.type_().to_string() - ) - }) - } - pub fn load(&self) -> DynExpr { - DynExpr { - node: __current_scope(|b| b.call(Func::Load, &[self.node], self.node.type_().clone())), - } - } - pub fn store(&self, value: &DynExpr) { - __current_scope(|b| b.update(self.node, value.node)); - } - pub fn zero() -> Self { - let v = local_zeroed::(); - Self { node: v.node() } - } -} -pub struct Shared { - marker: std::marker::PhantomData, - node: NodeRef, -} -impl Shared { - pub fn new(length: usize) -> Self { - Self { - marker: std::marker::PhantomData, - node: __current_scope(|b| { - let shared = new_node( - b.pools(), - Node::new( - CArc::new(Instruction::Shared), - ir::context::register_type(Type::Array(ArrayType { - element: T::type_(), - length, - })), - ), - ); - RECORDER.with(|r| { - let mut r = r.borrow_mut(); - r.shared.push(shared); - }); - shared - }), - } - } - pub fn len(&self) -> Expr { - match self.node.type_().as_ref() { - Type::Array(ArrayType { element: _, length }) => const_(*length as u64), - _ => unreachable!(), - } - } - pub fn static_len(&self) -> usize { - match self.node.type_().as_ref() { - Type::Array(ArrayType { element: _, length }) => *length, - _ => unreachable!(), - } - } - pub fn write>>(&self, i: I, value: V) { - let i = i.to_u64(); - let value = value.into(); - - if need_runtime_check() { - lc_assert!(i.cmplt(self.len()), "VLArrayVar::read out of bounds"); - } - - __current_scope(|b| { - let gep = b.call(Func::GetElementPtr, &[self.node, i.node()], T::type_()); - b.update(gep, value.node()); - }); - } - pub fn load(&self) -> VLArrayExpr { - VLArrayExpr::from_node(__current_scope(|b| { - b.call(Func::Load, &[self.node], self.node.type_().clone()) - })) - } - pub fn store(&self, value: VLArrayExpr) { - __current_scope(|b| { - b.update(self.node, value.node); - }); - } -} -#[derive(Clone, Copy, Debug)] -pub struct VLArrayExpr { - marker: std::marker::PhantomData, - node: NodeRef, -} - -impl FromNode for VLArrayExpr { - fn from_node(node: NodeRef) -> Self { - Self { - marker: std::marker::PhantomData, - node, - } - } -} - -impl ToNode for VLArrayExpr { - fn node(&self) -> NodeRef { - self.node - } -} - -impl Aggregate for VLArrayExpr { - fn to_nodes(&self, nodes: &mut Vec) { - nodes.push(self.node); - } - fn from_nodes>(iter: &mut I) -> Self { - Self::from_node(iter.next().unwrap()) - } -} - -#[derive(Clone, Copy, Debug)] -pub struct VLArrayVar { - marker: std::marker::PhantomData, - node: NodeRef, -} - -impl FromNode for VLArrayVar { - fn from_node(node: NodeRef) -> Self { - Self { - marker: std::marker::PhantomData, - node, - } - } -} - -impl ToNode for VLArrayVar { - fn node(&self) -> NodeRef { - self.node - } -} - -impl Aggregate for VLArrayVar { - fn to_nodes(&self, nodes: &mut Vec) { - nodes.push(self.node); - } - fn from_nodes>(iter: &mut I) -> Self { - Self::from_node(iter.next().unwrap()) - } -} - -impl VLArrayVar { - pub fn read>>(&self, i: I) -> Expr { - let i = i.into(); - if need_runtime_check() { - lc_assert!(i.cmplt(self.len()), "VLArrayVar::read out of bounds"); - } - - Expr::::from_node(__current_scope(|b| { - let gep = b.call(Func::GetElementPtr, &[self.node, i.node()], T::type_()); - b.call(Func::Load, &[gep], T::type_()) - })) - } - pub fn len(&self) -> Expr { - match self.node.type_().as_ref() { - Type::Array(ArrayType { element: _, length }) => const_(*length as u32), - _ => unreachable!(), - } - } - pub fn static_len(&self) -> usize { - match self.node.type_().as_ref() { - Type::Array(ArrayType { element: _, length }) => *length, - _ => unreachable!(), - } - } - pub fn write>, V: Into>>(&self, i: I, value: V) { - let i = i.into(); - let value = value.into(); - - if need_runtime_check() { - lc_assert!(i.cmplt(self.len()), "VLArrayVar::read out of bounds"); - } - - __current_scope(|b| { - let gep = b.call(Func::GetElementPtr, &[self.node, i.node()], T::type_()); - b.update(gep, value.node()); - }); - } - pub fn load(&self) -> VLArrayExpr { - VLArrayExpr::from_node(__current_scope(|b| { - b.call(Func::Load, &[self.node], self.node.type_().clone()) - })) - } - pub fn store(&self, value: VLArrayExpr) { - __current_scope(|b| { - b.update(self.node, value.node); - }); - } - pub fn zero(length: usize) -> Self { - FromNode::from_node(__current_scope(|b| { - b.local_zero_init(ir::context::register_type(Type::Array(ArrayType { - element: T::type_(), - length, - }))) - })) - } -} - -impl VLArrayExpr { - pub fn zero(length: usize) -> Self { - let node = __current_scope(|b| { - b.call( - Func::ZeroInitializer, - &[], - ir::context::register_type(Type::Array(ArrayType { - element: T::type_(), - length, - })), - ) - }); - Self::from_node(node) - } - pub fn static_len(&self) -> usize { - match self.node.type_().as_ref() { - Type::Array(ArrayType { element: _, length }) => *length, - _ => unreachable!(), - } - } - pub fn read(&self, i: I) -> Expr { - let i = i.to_u64(); - if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); - } - - Expr::::from_node(__current_scope(|b| { - b.call(Func::ExtractElement, &[self.node, i.node()], T::type_()) - })) - } - pub fn len(&self) -> Expr { - match self.node.type_().as_ref() { - Type::Array(ArrayType { element: _, length }) => const_(*length as u64), - _ => unreachable!(), - } - } -} - -impl IndexRead for ArrayExpr { - type Element = T; - fn read(&self, i: I) -> Expr { - let i = i.to_u64(); - - lc_assert!(i.cmplt(const_(N as u64))); - - Expr::::from_node(__current_scope(|b| { - b.call(Func::ExtractElement, &[self.node, i.node()], T::type_()) - })) - } -} - -impl IndexRead for ArrayVar { - type Element = T; - fn read(&self, i: I) -> Expr { - let i = i.to_u64(); - if need_runtime_check() { - lc_assert!(i.cmplt(const_(N as u64))); - } - - Expr::::from_node(__current_scope(|b| { - let gep = b.call(Func::GetElementPtr, &[self.node, i.node()], T::type_()); - b.call(Func::Load, &[gep], T::type_()) - })) - } -} - -impl IndexWrite for ArrayVar { - fn write>>(&self, i: I, value: V) { - let i = i.to_u64(); - let value = value.into(); - - if need_runtime_check() { - lc_assert!(i.cmplt(const_(N as u64))); - } - - __current_scope(|b| { - let gep = b.call(Func::GetElementPtr, &[self.node, i.node()], T::type_()); - b.update(gep, value.node()); - }); - } -} - -#[derive(Clone, Copy, Debug)] -pub struct ArrayExpr { - marker: std::marker::PhantomData, - node: NodeRef, -} - -#[derive(Clone, Copy, Debug)] -pub struct ArrayVar { - marker: std::marker::PhantomData, - node: NodeRef, -} - -impl FromNode for ArrayExpr { - fn from_node(node: NodeRef) -> Self { - Self { - marker: std::marker::PhantomData, - node, - } - } -} - -impl ToNode for ArrayExpr { - fn node(&self) -> NodeRef { - self.node - } -} - -impl Aggregate for ArrayExpr { - fn to_nodes(&self, nodes: &mut Vec) { - nodes.push(self.node); - } - fn from_nodes>(iter: &mut I) -> Self { - Self::from_node(iter.next().unwrap()) - } -} - -impl FromNode for ArrayVar { - fn from_node(node: NodeRef) -> Self { - Self { - marker: std::marker::PhantomData, - node, - } - } -} - -impl ToNode for ArrayVar { - fn node(&self) -> NodeRef { - self.node - } -} - -impl Aggregate for ArrayVar { - fn to_nodes(&self, nodes: &mut Vec) { - nodes.push(self.node); - } - fn from_nodes>(iter: &mut I) -> Self { - Self::from_node(iter.next().unwrap()) - } -} - -impl ExprProxy for ArrayExpr { - type Value = [T; N]; -} - -impl VarProxy for ArrayVar { - type Value = [T; N]; -} - -impl ArrayVar { - pub fn len(&self) -> Expr { - const_(N as u32) - } -} - -impl ArrayExpr { - pub fn zero() -> Self { - let node = __current_scope(|b| b.call(Func::ZeroInitializer, &[], <[T; N]>::type_())); - Self::from_node(node) - } - pub fn len(&self) -> Expr { - const_(N as u32) - } -} - -// Not recommended to use this directly -pub struct KernelBuilder { - device: Option, - args: Vec, -} - -pub trait CallableParameter: Sized + Clone + 'static { - fn def_param(arg: Option>, builder: &mut KernelBuilder) -> Self; - fn encode(&self, encoder: &mut CallableArgEncoder); -} -macro_rules! impl_callable_parameter_for_tuple { - ()=>{ - impl CallableParameter for () { - fn def_param(_: Option>, _: &mut KernelBuilder) {} - fn encode(&self, _: &mut CallableArgEncoder) { } - } - }; - ($first:ident $($rest:ident) *) => { - impl<$first:CallableParameter, $($rest: CallableParameter),*> CallableParameter for ($first, $($rest,)*) { - #[allow(non_snake_case)] - fn def_param(arg: Option>, builder: &mut KernelBuilder) -> Self { - if let Some(arg) = arg { - let ($first, $($rest,)*) = arg.downcast_ref::<($first, $($rest,)*)>().cloned().unwrap(); - let $first = $first::def_param(Some(std::rc::Rc::new($first)), builder); - let ($($rest,)*) = ($($rest::def_param(Some(std::rc::Rc::new($rest)), builder),)*); - ($first, $($rest,)*) - }else { - let $first = $first::def_param(None, builder); - let ($($rest,)*) = ($($rest::def_param(None, builder),)*); - ($first, $($rest,)*) - } - } - #[allow(non_snake_case)] - fn encode(&self, encoder: &mut CallableArgEncoder) { - let ($first, $($rest,)*) = self; - $first.encode(encoder); - $($rest.encode(encoder);)* - } - } - impl_callable_parameter_for_tuple!($($rest)*); - }; - -} -impl_callable_parameter_for_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); - -impl CallableParameter for BufferVar { - fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { - builder.buffer() - } - fn encode(&self, encoder: &mut CallableArgEncoder) { - encoder.buffer(self) - } -} -impl CallableParameter for ByteBufferVar { - fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { - builder.byte_buffer() - } - fn encode(&self, encoder: &mut CallableArgEncoder) { - encoder.byte_buffer(self) - } -} -impl CallableParameter for Tex2dVar { - fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { - builder.tex2d() - } - fn encode(&self, encoder: &mut CallableArgEncoder) { - encoder.tex2d(self) - } -} - -impl CallableParameter for Tex3dVar { - fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { - builder.tex3d() - } - fn encode(&self, encoder: &mut CallableArgEncoder) { - encoder.tex3d(self) - } -} - -impl CallableParameter for BindlessArrayVar { - fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { - builder.bindless_array() - } - fn encode(&self, encoder: &mut CallableArgEncoder) { - encoder.bindless_array(self) - } -} - -impl CallableParameter for rtx::AccelVar { - fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { - builder.accel() - } - fn encode(&self, encoder: &mut CallableArgEncoder) { - encoder.accel(self) - } -} - -pub trait KernelParameter { - fn def_param(builder: &mut KernelBuilder) -> Self; -} - -impl KernelParameter for U -where - U: ExprProxy, - T: Value, -{ - fn def_param(builder: &mut KernelBuilder) -> Self { - builder.uniform::() - } -} -impl KernelParameter for ByteBufferVar { - fn def_param(builder: &mut KernelBuilder) -> Self { - builder.byte_buffer() - } -} -impl KernelParameter for BufferVar { - fn def_param(builder: &mut KernelBuilder) -> Self { - builder.buffer() - } -} - -impl KernelParameter for Tex2dVar { - fn def_param(builder: &mut KernelBuilder) -> Self { - builder.tex2d() - } -} - -impl KernelParameter for Tex3dVar { - fn def_param(builder: &mut KernelBuilder) -> Self { - builder.tex3d() - } -} - -impl KernelParameter for BindlessArrayVar { - fn def_param(builder: &mut KernelBuilder) -> Self { - builder.bindless_array() - } -} - -impl KernelParameter for rtx::AccelVar { - fn def_param(builder: &mut KernelBuilder) -> Self { - builder.accel() - } -} -macro_rules! impl_kernel_param_for_tuple { - ($first:ident $($rest:ident)*) => { - impl<$first:KernelParameter, $($rest: KernelParameter),*> KernelParameter for ($first, $($rest,)*) { - #[allow(non_snake_case)] - fn def_param(builder: &mut KernelBuilder) -> Self { - ($first::def_param(builder), $($rest::def_param(builder)),*) - } - } - impl_kernel_param_for_tuple!($($rest)*); - }; - ()=>{ - impl KernelParameter for () { - fn def_param(_: &mut KernelBuilder) -> Self { - } - } - } -} -impl_kernel_param_for_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); -impl KernelBuilder { - pub fn new(device: Option, is_kernel: bool) -> Self { - RECORDER.with(|r| { - let mut r = r.borrow_mut(); - assert!(!r.lock, "Cannot record multiple kernels at the same time"); - assert!( - r.scopes.is_empty(), - "Cannot record multiple kernels at the same time" - ); - r.lock = true; - r.device = device.as_ref().map(|d| WeakDevice::new(d)); - r.pools = Some(CArc::new(ModulePools::new())); - r.scopes.clear(); - r.building_kernel = is_kernel; - let pools = r.pools.clone().unwrap(); - r.scopes.push(IrBuilder::new(pools)); - }); - Self { - device, - args: vec![], - } - } - pub(crate) fn arg(&mut self, ty: CArc, by_value: bool) -> NodeRef { - let node = new_node( - __module_pools(), - Node::new(CArc::new(Instruction::Argument { by_value }), ty), - ); - self.args.push(node); - node - } - pub fn value(&mut self) -> Expr { - let node = self.arg(T::type_(), true); - FromNode::from_node(node) - } - pub fn var(&mut self) -> Var { - let node = self.arg(T::type_(), false); - FromNode::from_node(node) - } - pub fn uniform(&mut self) -> Expr { - let node = new_node( - __module_pools(), - Node::new(CArc::new(Instruction::Uniform), T::type_()), - ); - self.args.push(node); - FromNode::from_node(node) - } - pub fn byte_buffer(&mut self) -> ByteBufferVar { - let node = new_node( - __module_pools(), - Node::new(CArc::new(Instruction::Buffer), Type::void()), - ); - self.args.push(node); - ByteBufferVar { node, handle: None } - } - pub fn buffer(&mut self) -> BufferVar { - let node = new_node( - __module_pools(), - Node::new(CArc::new(Instruction::Buffer), T::type_()), - ); - self.args.push(node); - BufferVar { - node, - marker: std::marker::PhantomData, - handle: None, - } - } - pub fn tex2d(&mut self) -> Tex2dVar { - let node = new_node( - __module_pools(), - Node::new(CArc::new(Instruction::Texture2D), T::type_()), - ); - self.args.push(node); - Tex2dVar { - node, - marker: std::marker::PhantomData, - handle: None, - level: None, - } - } - pub fn tex3d(&mut self) -> Tex3dVar { - let node = new_node( - __module_pools(), - Node::new(CArc::new(Instruction::Texture3D), T::type_()), - ); - self.args.push(node); - Tex3dVar { - node, - marker: std::marker::PhantomData, - handle: None, - level: None, - } - } - pub fn bindless_array(&mut self) -> BindlessArrayVar { - let node = new_node( - __module_pools(), - Node::new(CArc::new(Instruction::Bindless), Type::void()), - ); - self.args.push(node); - BindlessArrayVar { node, handle: None } - } - pub fn accel(&mut self) -> rtx::AccelVar { - let node = new_node( - __module_pools(), - Node::new(CArc::new(Instruction::Accel), Type::void()), - ); - self.args.push(node); - rtx::AccelVar { node, handle: None } - } - fn collect_module_info(&self) -> (ResourceTracker, Vec>, Vec) { - RECORDER.with(|r| { - let mut resource_tracker = ResourceTracker::new(); - let r = r.borrow_mut(); - let mut captured: Vec = Vec::new(); - let mut captured_buffers: Vec<_> = r.captured_buffer.values().cloned().collect(); - captured_buffers.sort_by_key(|(i, _, _, _)| *i); - for (j, (i, node, binding, handle)) in captured_buffers.into_iter().enumerate() { - assert_eq!(j, i); - captured.push(Capture { node, binding }); - resource_tracker.add_any(handle); - } - let mut cpu_custom_ops: Vec<_> = r.cpu_custom_ops.values().cloned().collect(); - cpu_custom_ops.sort_by_key(|(i, _)| *i); - let mut cpu_custom_ops: Vec> = cpu_custom_ops - .iter() - .enumerate() - .map(|(j, (i, op))| { - assert_eq!(j, *i); - (*op).clone() - }) - .collect::>(); - let callables: Vec = r.callables.values().cloned().collect(); - let mut captured_set = HashSet::::new(); - let mut cpu_custom_ops_set = HashSet::::new(); - let mut callable_set = HashSet::::new(); - for capture in captured.iter() { - captured_set.insert(*capture); - } - for op in &cpu_custom_ops { - cpu_custom_ops_set.insert(CArc::as_ptr(op) as u64); - } - for c in &callables { - callable_set.insert(CArc::as_ptr(&c.0) as u64); - for capture in c.0.captures.as_ref() { - if !captured_set.contains(capture) { - captured_set.insert(*capture); - captured.push(*capture); - } - } - for op in c.0.cpu_custom_ops.as_ref() { - let id = CArc::as_ptr(op) as u64; - if !cpu_custom_ops_set.contains(&id) { - cpu_custom_ops_set.insert(id); - cpu_custom_ops.push(op.clone()); - } - } - } - (resource_tracker, cpu_custom_ops, captured) - }) - } - fn build_callable(&mut self, body: impl FnOnce(&mut Self) -> R) -> RawCallable { - let ret = body(self); - let ret_type = ret._return(); - let (rt, cpu_custom_ops, captures) = self.collect_module_info(); - RECORDER.with(|r| { - let mut r = r.borrow_mut(); - assert!(r.lock); - r.lock = false; - if let Some(t) = &r.callable_ret_type { - assert!( - luisa_compute_ir::context::is_type_equal(t, &ret_type), - "Return type mismatch" - ); - } else { - r.callable_ret_type = Some(ret_type.clone()); - } - assert_eq!(r.scopes.len(), 1); - let scope = r.scopes.pop().unwrap(); - let entry = scope.finish(); - let ir_module = Module { - entry, - kind: ModuleKind::Kernel, - pools: r.pools.clone().unwrap(), - flags: ModuleFlags::REQUIRES_REV_AD_TRANSFORM - | ModuleFlags::REQUIRES_FWD_AD_TRANSFORM, - }; - let ir_module = luisa_compute_ir::transform::luisa_compute_ir_transform_auto(ir_module); - let module = CallableModule { - module: ir_module, - ret_type, - cpu_custom_ops: CBoxedSlice::new(cpu_custom_ops), - captures: CBoxedSlice::new(captures), - args: CBoxedSlice::new(self.args.clone()), - pools: r.pools.clone().unwrap(), - }; - let module = CallableModuleRef(CArc::new(module)); - r.reset(); - RawCallable { - module, - resource_tracker: rt, - } - }) - } - fn build_kernel( - &mut self, - options: KernelBuildOptions, - body: impl FnOnce(&mut Self), - ) -> crate::runtime::RawKernel { - body(self); - let (rt, cpu_custom_ops, captures) = self.collect_module_info(); - RECORDER.with(|r| -> crate::runtime::RawKernel { - let mut r = r.borrow_mut(); - assert!(r.lock); - r.lock = false; - assert_eq!(r.scopes.len(), 1); - let scope = r.scopes.pop().unwrap(); - let entry = scope.finish(); - - let ir_module = Module { - entry, - kind: ModuleKind::Kernel, - pools: r.pools.clone().unwrap(), - flags: ModuleFlags::REQUIRES_REV_AD_TRANSFORM - | ModuleFlags::REQUIRES_FWD_AD_TRANSFORM, - }; - let ir_module = luisa_compute_ir::transform::luisa_compute_ir_transform_auto(ir_module); - let module = KernelModule { - module: ir_module, - cpu_custom_ops: CBoxedSlice::new(cpu_custom_ops), - captures: CBoxedSlice::new(captures), - shared: CBoxedSlice::new(r.shared.clone()), - args: CBoxedSlice::new(self.args.clone()), - block_size: r.block_size.unwrap_or([64, 1, 1]), - pools: r.pools.clone().unwrap(), - }; - - let module = CArc::new(module); - let name = options.name.unwrap_or("".to_string()); - let name = Arc::new(CString::new(name).unwrap()); - let shader_options = api::ShaderOption { - enable_cache: options.enable_cache, - enable_fast_math: options.enable_fast_math, - enable_debug_info: options.enable_debug_info, - compile_only: false, - name: name.as_ptr(), - }; - let artifact = if options.async_compile { - ShaderArtifact::Async(AsyncShaderArtifact::new( - self.device.clone().unwrap(), - module.clone(), - shader_options, - name, - )) - } else { - ShaderArtifact::Sync( - self.device - .as_ref() - .unwrap() - .inner - .create_shader(&module, &shader_options), - ) - }; - // - r.reset(); - RawKernel { - artifact, - device: self.device.clone().unwrap(), - resource_tracker: rt, - module, - } - }) - } -} - -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub struct KernelBuildOptions { - pub enable_debug_info: bool, - pub enable_optimization: bool, - pub async_compile: bool, - pub enable_cache: bool, - pub enable_fast_math: bool, - pub name: Option, -} - -impl Default for KernelBuildOptions { - fn default() -> Self { - let enable_debug_info = match env::var("LUISA_DEBUG") { - Ok(s) => s == "1", - Err(_) => false, - }; - Self { - enable_debug_info, - enable_optimization: true, - async_compile: false, - enable_cache: true, - enable_fast_math: true, - name: None, - } - } -} - -pub trait KernelBuildFn { - fn build_kernel( - &self, - builder: &mut KernelBuilder, - options: KernelBuildOptions, - ) -> crate::runtime::RawKernel; -} - -pub trait CallableBuildFn { - fn build_callable(&self, args: Option>, builder: &mut KernelBuilder) - -> RawCallable; -} - -pub trait StaticCallableBuildFn: CallableBuildFn {} - -// @FIXME: this looks redundant -pub unsafe trait CallableRet { - fn _return(&self) -> CArc; - fn _from_return(node: NodeRef) -> Self; -} - -unsafe impl CallableRet for () { - fn _return(&self) -> CArc { - Type::void() - } - fn _from_return(_: NodeRef) -> Self {} -} - -unsafe impl CallableRet for T { - fn _return(&self) -> CArc { - __current_scope(|b| { - b.return_(self.node()); - }); - T::Value::type_() - } - fn _from_return(node: NodeRef) -> Self { - Self::from_node(node) - } -} - -pub trait CallableSignature<'a> { - type Callable; - type DynCallable; - type Fn: CallableBuildFn; - type StaticFn: StaticCallableBuildFn; - type DynFn: CallableBuildFn + 'static; - type Ret: CallableRet; - fn wrap_raw_callable(callable: RawCallable) -> Self::Callable; - fn create_dyn_callable(device: Device, init_once: bool, f: Self::DynFn) -> Self::DynCallable; -} - -pub trait KernelSignature<'a> { - type Fn: KernelBuildFn; - type Kernel; - - fn wrap_raw_kernel(kernel: crate::runtime::RawKernel) -> Self::Kernel; -} -macro_rules! impl_callable_signature { - ()=>{ - impl<'a, R: CallableRet +'static> CallableSignature<'a> for fn()->R { - type Fn = &'a dyn Fn() ->R; - type DynFn = BoxR>; - type StaticFn = fn() -> R; - type Callable = CallableR>; - type DynCallable = DynCallableR>; - type Ret = R; - fn wrap_raw_callable(callable: RawCallable) -> Self::Callable{ - Callable { - inner: callable, - _marker:std::marker::PhantomData, - } - } - fn create_dyn_callable(device:Device, init_once:bool, f: Self::DynFn) -> Self::DynCallable { - DynCallable::new(device, init_once, Box::new(move |arg, builder| { - let raw_callable = CallableBuildFn::build_callable(&f, Some(arg), builder); - Self::wrap_raw_callable(raw_callable) - })) - } - } - }; - ($first:ident $($rest:ident)*) => { - impl<'a, R:CallableRet +'static, $first:CallableParameter +'static, $($rest: CallableParameter +'static),*> CallableSignature<'a> for fn($first, $($rest,)*)->R { - type Fn = &'a dyn Fn($first, $($rest),*)->R; - type DynFn = BoxR>; - type Callable = CallableR>; - type StaticFn = fn($first, $($rest,)*)->R; - type DynCallable = DynCallableR>; - type Ret = R; - fn wrap_raw_callable(callable: RawCallable) -> Self::Callable{ - Callable { - inner: callable, - _marker:std::marker::PhantomData, - } - } - fn create_dyn_callable(device:Device, init_once:bool, f: Self::DynFn) -> Self::DynCallable { - DynCallable::new(device, init_once, Box::new(move |arg, builder| { - let raw_callable = CallableBuildFn::build_callable(&f, Some(arg), builder); - Self::wrap_raw_callable(raw_callable) - })) - } - } - impl_callable_signature!($($rest)*); - }; -} -impl_callable_signature!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); -macro_rules! impl_kernel_signature { - ()=>{ - impl<'a> KernelSignature<'a> for fn() { - type Fn = &'a dyn Fn(); - type Kernel = Kernel; - fn wrap_raw_kernel(kernel: crate::runtime::RawKernel) -> Self::Kernel { - Self::Kernel{ - inner:kernel, - _marker:std::marker::PhantomData, - } - } - } - }; - ($first:ident $($rest:ident)*) => { - impl<'a, $first:KernelArg +'static, $($rest: KernelArg +'static),*> KernelSignature<'a> for fn($first, $($rest,)*) { - type Fn = &'a dyn Fn($first::Parameter, $($rest::Parameter),*); - type Kernel = Kernel; - fn wrap_raw_kernel(kernel: crate::runtime::RawKernel) -> Self::Kernel { - Self::Kernel{ - inner:kernel, - _marker:std::marker::PhantomData, - } - } - } - impl_kernel_signature!($($rest)*); - }; -} -impl_kernel_signature!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); - -macro_rules! impl_callable_build_for_fn { - ()=>{ - impl CallableBuildFn for &dyn Fn()->R { - fn build_callable(&self, _args: Option>, builder: &mut KernelBuilder)->RawCallable { - builder.build_callable( |_| { - self() - }) - } - } - impl CallableBuildFn for fn()->R { - fn build_callable(&self, _args: Option>, builder: &mut KernelBuilder)->RawCallable { - builder.build_callable( |_| { - self() - }) - } - } - impl CallableBuildFn for BoxR> { - fn build_callable(&self, _args: Option>, builder: &mut KernelBuilder)->RawCallable { - builder.build_callable( |_| { - self() - }) - } - } - impl StaticCallableBuildFn for fn()->R {} - }; - ($first:ident $($rest:ident)*) => { - impl CallableBuildFn for &dyn Fn($first, $($rest,)*)->R { - #[allow(non_snake_case)] - fn build_callable(&self, args: Option>, builder: &mut KernelBuilder)->RawCallable { - builder.build_callable( |builder| { - if let Some(args) = args { - let ($first, $($rest,)*) = args.downcast_ref::<($first, $($rest,)*)>().cloned().unwrap(); - let $first = $first::def_param(Some(Rc::new($first)), builder); - $(let $rest = $rest::def_param(Some(Rc::new($rest)), builder);)* - self($first, $($rest,)*) - } else { - let $first = $first::def_param(None, builder); - $(let $rest = $rest::def_param(None, builder);)* - self($first, $($rest,)*) - } - }) - } - } - impl CallableBuildFn for BoxR> { - #[allow(non_snake_case)] - fn build_callable(&self, args: Option>, builder: &mut KernelBuilder)->RawCallable { - builder.build_callable( |builder| { - if let Some(args) = args { - let ($first, $($rest,)*) = args.downcast_ref::<($first, $($rest,)*)>().cloned().unwrap(); - let $first = $first::def_param(Some(Rc::new($first)), builder); - $(let $rest = $rest::def_param(Some(Rc::new($rest)), builder);)* - self($first, $($rest,)*) - } else { - let $first = $first::def_param(None, builder); - $(let $rest = $rest::def_param(None, builder);)* - self($first, $($rest,)*) - } - }) - } - } - impl CallableBuildFn for fn($first, $($rest,)*)->R { - #[allow(non_snake_case)] - fn build_callable(&self, args: Option>, builder: &mut KernelBuilder)->RawCallable { - builder.build_callable( |builder| { - if let Some(args) = args { - let ($first, $($rest,)*) = args.downcast_ref::<($first, $($rest,)*)>().cloned().unwrap(); - let $first = $first::def_param(Some(Rc::new($first)), builder); - $(let $rest = $rest::def_param(Some(Rc::new($rest)), builder);)* - self($first, $($rest,)*) - } else { - let $first = $first::def_param(None, builder); - $(let $rest = $rest::def_param(None, builder);)* - self($first, $($rest,)*) - } - }) - } - } - impl StaticCallableBuildFn for fn($first, $($rest,)*)->R {} - impl_callable_build_for_fn!($($rest)*); - }; -} -impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); -macro_rules! impl_kernel_build_for_fn { - ()=>{ - impl KernelBuildFn for &dyn Fn() { - fn build_kernel(&self, builder: &mut KernelBuilder, options:KernelBuildOptions) -> crate::runtime::RawKernel { - builder.build_kernel(options, |_| { - self() - }) - } - } - }; - ($first:ident $($rest:ident)*) => { - impl<$first:KernelParameter, $($rest: KernelParameter),*> KernelBuildFn for &dyn Fn($first, $($rest,)*) { - #[allow(non_snake_case)] - fn build_kernel(&self, builder: &mut KernelBuilder, options:KernelBuildOptions) -> crate::runtime::RawKernel { - builder.build_kernel(options, |builder| { - let $first = $first::def_param(builder); - $(let $rest = $rest::def_param(builder);)* - self($first, $($rest,)*) - }) - } - } - impl_kernel_build_for_fn!($($rest)*); - }; -} -impl_kernel_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); - -pub fn if_then_else( - cond: impl Mask, - then: impl Fn() -> R, - else_: impl Fn() -> R, -) -> R { - let cond = cond.node(); - RECORDER.with(|r| { - let mut r = r.borrow_mut(); - let pools = r.pools.clone().unwrap(); - let s = &mut r.scopes; - s.push(IrBuilder::new(pools)); - }); - let then = then(); - let then_block = RECORDER.with(|r| { - let mut r = r.borrow_mut(); - let pools = r.pools.clone().unwrap(); - let s = &mut r.scopes; - let then_block = s.pop().unwrap().finish(); - s.push(IrBuilder::new(pools)); - then_block - }); - let else_ = else_(); - let else_block = RECORDER.with(|r| { - let mut r = r.borrow_mut(); - let s = &mut r.scopes; - s.pop().unwrap().finish() - }); - let then_nodes = then.to_vec_nodes(); - let else_nodes = else_.to_vec_nodes(); - __current_scope(|b| { - b.if_(cond, then_block, else_block); - }); - assert_eq!(then_nodes.len(), else_nodes.len()); - let phis = __current_scope(|b| { - then_nodes - .iter() - .zip(else_nodes.iter()) - .map(|(then, else_)| { - let incomings = vec![ - PhiIncoming { - value: *then, - block: then_block, - }, - PhiIncoming { - value: *else_, - block: else_block, - }, - ]; - assert_eq!(then.type_(), else_.type_()); - let phi = b.phi(&incomings, then.type_().clone()); - phi - }) - .collect::>() - }); - R::from_vec_nodes(phis) -} - -pub fn generic_loop(cond: impl Fn() -> Bool, body: impl Fn(), update: impl Fn()) { - RECORDER.with(|r| { - let mut r = r.borrow_mut(); - let pools = r.pools.clone().unwrap(); - let s = &mut r.scopes; - s.push(IrBuilder::new(pools)); - }); - let cond_v = cond().node(); - let prepare = RECORDER.with(|r| { - let mut r = r.borrow_mut(); - let pools = r.pools.clone().unwrap(); - let s = &mut r.scopes; - let prepare = s.pop().unwrap().finish(); - s.push(IrBuilder::new(pools)); - prepare - }); - body(); - let body = RECORDER.with(|r| { - let mut r = r.borrow_mut(); - let pools = r.pools.clone().unwrap(); - let s = &mut r.scopes; - let body = s.pop().unwrap().finish(); - s.push(IrBuilder::new(pools)); - body - }); - update(); - let update = RECORDER.with(|r| { - let mut r = r.borrow_mut(); - let s = &mut r.scopes; - s.pop().unwrap().finish() - }); - __current_scope(|b| { - b.generic_loop(prepare, cond_v, body, update); - }); -} - -pub struct SwitchBuilder { - cases: Vec<(i32, Pooled, Vec)>, - default: Option<(Pooled, Vec)>, - value: NodeRef, - _marker: PhantomData, - depth: usize, -} - -pub fn switch(node: Expr) -> SwitchBuilder { - SwitchBuilder::new(node) -} - -impl SwitchBuilder { - pub fn new(node: Expr) -> Self { - SwitchBuilder { - cases: vec![], - default: None, - value: node.node(), - _marker: PhantomData, - depth: RECORDER.with(|r| r.borrow().scopes.len()), - } - } - pub fn case(mut self, value: i32, then: impl Fn() -> R) -> Self { - RECORDER.with(|r| { - let mut r = r.borrow_mut(); - let pools = r.pools.clone().unwrap(); - let s = &mut r.scopes; - assert_eq!(s.len(), self.depth); - s.push(IrBuilder::new(pools)); - }); - let then = then(); - let block = __pop_scope(); - self.cases.push((value, block, then.to_vec_nodes())); - self - } - pub fn default(mut self, then: impl Fn() -> R) -> Self { - RECORDER.with(|r| { - let mut r = r.borrow_mut(); - let pools = r.pools.clone().unwrap(); - let s = &mut r.scopes; - assert_eq!(s.len(), self.depth); - s.push(IrBuilder::new(pools)); - }); - let then = then(); - let block = __pop_scope(); - self.default = Some((block, then.to_vec_nodes())); - self - } - pub fn finish(self) -> R { - RECORDER.with(|r| { - let mut r = r.borrow_mut(); - let s = &mut r.scopes; - assert_eq!(s.len(), self.depth); - }); - let cases = self - .cases - .iter() - .map(|(v, b, _)| SwitchCase { - value: *v, - block: *b, - }) - .collect::>(); - let case_phis = self - .cases - .iter() - .map(|(_, _, nodes)| nodes.clone()) - .collect::>(); - let phi_count = case_phis[0].len(); - let mut default_nodes = vec![]; - let default_block = if self.default.is_none() { - RECORDER.with(|r| { - let mut r = r.borrow_mut(); - let pools = r.pools.clone().unwrap(); - let s = &mut r.scopes; - assert_eq!(s.len(), self.depth); - s.push(IrBuilder::new(pools)); - }); - for i in 0..phi_count { - let msg = CString::new("unreachable code in switch statement!").unwrap(); - let default_node = __current_scope(|b| { - b.call( - Func::Unreachable(CBoxedSlice::from(msg)), - &[], - case_phis[0][i].type_().clone(), - ) - }); - default_nodes.push(default_node); - } - __pop_scope() - } else { - default_nodes = self.default.as_ref().unwrap().1.clone(); - self.default.as_ref().unwrap().0 - }; - __current_scope(|b| { - b.switch(self.value, &cases, default_block); - }); - let mut phis = vec![]; - for i in 0..phi_count { - let mut incomings = vec![]; - for (j, nodes) in case_phis.iter().enumerate() { - incomings.push(PhiIncoming { - value: nodes[i], - block: self.cases[j].1, - }); - } - incomings.push(PhiIncoming { - value: default_nodes[i], - block: default_block, - }); - let phi = __current_scope(|b| b.phi(&incomings, case_phis[0][i].type_().clone())); - phis.push(phi); - } - R::from_vec_nodes(phis) - } -} - -#[macro_export] -/** - * If you want rustfmt to format your code, use if_!(cond, { .. }, { .. }) or if_!(cond, { .. }, else, {...}) - * instead of if_!(cond, { .. }, else {...}). - * - */ -macro_rules! if_ { - ($cond:expr, $then:block, else $else_:block) => { - $crate::lang::if_then_else($cond, || $then, || $else_) - }; - ($cond:expr, $then:block, else, $else_:block) => { - $crate::lang::if_then_else($cond, || $then, || $else_) - }; - ($cond:expr, $then:block, $else_:block) => { - $crate::lang::if_then_else($cond, || $then, || $else_) - }; - ($cond:expr, $then:block) => { - $crate::lang::if_then_else($cond, || $then, || {}) - }; -} -#[macro_export] -macro_rules! while_ { - ($cond:expr,$body:block) => { - $crate::lang::generic_loop(|| $cond, || $body, || {}) - }; -} -#[macro_export] -macro_rules! loop_ { - ($body:block) => { - $crate::while_!(const_(true), $body) - }; -} -pub trait ForLoopRange { - type Element: Value; - fn start(&self) -> NodeRef; - fn end(&self) -> NodeRef; - fn end_inclusive(&self) -> bool; -} -macro_rules! impl_range { - ($t:ty) => { - impl ForLoopRange for std::ops::RangeInclusive<$t> { - type Element = $t; - fn start(&self) -> NodeRef { - const_(*self.start()).node() - } - fn end(&self) -> NodeRef { - const_(*self.end()).node() - } - fn end_inclusive(&self) -> bool { - true - } - } - impl ForLoopRange for std::ops::RangeInclusive> { - type Element = $t; - fn start(&self) -> NodeRef { - self.start().node() - } - fn end(&self) -> NodeRef { - self.end().node() - } - fn end_inclusive(&self) -> bool { - true - } - } - impl ForLoopRange for std::ops::Range<$t> { - type Element = $t; - fn start(&self) -> NodeRef { - const_(self.start).node() - } - fn end(&self) -> NodeRef { - const_(self.end).node() - } - fn end_inclusive(&self) -> bool { - false - } - } - impl ForLoopRange for std::ops::Range> { - type Element = $t; - fn start(&self) -> NodeRef { - self.start.node() - } - fn end(&self) -> NodeRef { - self.end.node() - } - fn end_inclusive(&self) -> bool { - false - } - } - }; -} -impl_range!(i32); -impl_range!(i64); -impl_range!(u32); -impl_range!(u64); - -#[inline] -pub fn for_range(r: R, body: impl Fn(Expr)) { - let start = r.start(); - let end = r.end(); - let inc = |v: NodeRef| { - __current_scope(|b| { - let one = b.const_(Const::One(v.type_().clone())); - b.call(Func::Add, &[v, one], v.type_().clone()) - }) - }; - let i = __current_scope(|b| b.local(start)); - generic_loop( - || { - __current_scope(|b| { - let i = b.call(Func::Load, &[i], i.type_().clone()); - Bool::from_node(b.call( - if r.end_inclusive() { - Func::Le - } else { - Func::Lt - }, - &[i, end], - ::type_(), - )) - }) - }, - move || { - let i = __current_scope(|b| b.call(Func::Load, &[i], i.type_().clone())); - body(Expr::::from_node(i)); - }, - || { - let i_old = __current_scope(|b| b.call(Func::Load, &[i], i.type_().clone())); - let i_new = inc(i_old); - __current_scope(|b| b.update(i, i_new)); - }, - ) -} - -#[inline] -pub fn break_() { - __current_scope(|b| { - b.break_(); - }); -} - -#[inline] -pub fn continue_() { - __current_scope(|b| { - b.continue_(); - }); -} - -pub fn return_v(v: T) { - RECORDER.with(|r| { - let mut r = r.borrow_mut(); - if r.callable_ret_type.is_none() { - r.callable_ret_type = Some(v.node().type_().clone()); - } else { - assert!( - luisa_compute_ir::context::is_type_equal( - r.callable_ret_type.as_ref().unwrap(), - v.node().type_() - ), - "return type mismatch" - ); - } - }); - __current_scope(|b| { - b.return_(v.node()); - }); -} - -pub fn return_() { - RECORDER.with(|r| { - let mut r = r.borrow_mut(); - if r.callable_ret_type.is_none() { - r.callable_ret_type = Some(Type::void()); - } else { - assert!(luisa_compute_ir::context::is_type_equal( - r.callable_ret_type.as_ref().unwrap(), - &Type::void() - )); - } - }); - __current_scope(|b| { - b.return_(INVALID_REF); - }); -} - -struct AdContext { - started: bool, - backward_called: bool, - is_forward_mode: bool, - n_forward_grads: usize, - // forward: Option>, -} - -impl AdContext { - fn new_rev() -> Self { - Self { - started: false, - backward_called: false, - is_forward_mode: false, - n_forward_grads: 0, - } - } - fn new_fwd(n: usize) -> Self { - Self { - started: false, - backward_called: false, - is_forward_mode: true, - n_forward_grads: n, - } - } - fn reset(&mut self) { - self.started = false; - } -} -thread_local! { - static AD_CONTEXT:RefCell = RefCell::new(AdContext::new_rev()); -} -pub fn requires_grad(var: impl ExprProxy) { - AD_CONTEXT.with(|c| { - let c = c.borrow(); - assert!(c.started, "autodiff section is not started"); - assert!( - !c.is_forward_mode, - "requires_grad() is called in forward mode" - ); - assert!(!c.backward_called, "backward is already called"); - }); - __current_scope(|b| { - b.call(Func::RequiresGradient, &[var.node()], Type::void()); - }); -} - -pub fn backward(out: T) { - backward_with_grad( - out, - FromNode::from_node(__current_scope(|b| { - let one = new_node( - b.pools(), - Node::new( - CArc::new(Instruction::Const(Const::One(::type_()))), - ::type_(), - ), - ); - b.append(one); - one - })), - ); -} - -pub fn backward_with_grad(out: T, grad: T) { - AD_CONTEXT.with(|c| { - let mut c = c.borrow_mut(); - assert!(c.started, "autodiff section is not started"); - assert!(!c.is_forward_mode, "backward() is called in forward mode"); - assert!(!c.backward_called, "backward is already called"); - c.backward_called = true; - }); - let out = out.node(); - let grad = grad.node(); - __current_scope(|b| { - b.call(Func::GradientMarker, &[out, grad], Type::void()); - b.call(Func::Backward, &[], Type::void()); - }); -} - -/// Gradient of a value in *Reverse mode* AD -pub fn gradient(var: T) -> T { - AD_CONTEXT.with(|c| { - let c = c.borrow(); - assert!(c.started, "autodiff section is not started"); - assert!(!c.is_forward_mode, "gradient() is called in forward mode"); - assert!(c.backward_called, "backward is not called"); - }); - T::from_node(__current_scope(|b| { - b.call(Func::Gradient, &[var.node()], var.node().type_().clone()) - })) -} -/// Gradient of a value in *Reverse mode* AD -pub fn grad(var: T) -> T { - gradient(var) -} - -// pub fn detach(body: impl FnOnce() -> R) -> R { -// RECORDER.with(|r| { -// let mut r = r.borrow_mut(); -// let s = &mut r.scopes; -// s.push(IrBuilder::new()); -// }); -// let ret = body(); -// let fwd = pop_scope(); -// __current_scope(|b| { -// let node = new_node(Node::new(CArc::new(Instruction::AdDetach(fwd)), Type::void())); -// b.append(node); -// }); -// let nodes = ret.to_vec_nodes(); -// let nodes: Vec<_> = nodes -// .iter() -// .map(|n| __current_scope(|b| b.call(Func::Detach, &[*n], n.type_()))) -// .collect(); -// R::from_vec_nodes(nodes) -// } -pub fn detach(v: T) -> T { - let v = v.node(); - let node = __current_scope(|b| b.call(Func::Detach, &[v], v.type_().clone())); - T::from_node(node) -} - -/// Start a *Forward mode* AD section that propagates N gradients w.r.t to input variable -pub fn forward_autodiff(n_grads: usize, body: impl Fn()) { - AD_CONTEXT.with(|c| { - let mut c = c.borrow_mut(); - assert!(!c.started, "autodiff section already started"); - *c = AdContext::new_fwd(n_grads); - c.started = true; - }); - RECORDER.with(|r| { - let mut r = r.borrow_mut(); - let pools = r.pools.clone().unwrap(); - let s = &mut r.scopes; - s.push(IrBuilder::new(pools)); - }); - body(); - let n_grads = AD_CONTEXT.with(|c| { - let mut c = c.borrow_mut(); - let n_grads = c.n_forward_grads; - c.reset(); - n_grads - }); - let body = __pop_scope(); - __current_scope(|b| { - b.fwd_ad_scope(body, n_grads); - }); -} - -/// Propagate N gradients w.r.t to input variable using *Forward mode* AD -pub fn propagate_gradient(v: T, grads: &[T]) { - AD_CONTEXT.with(|c| { - let c = c.borrow(); - assert_eq!(grads.len(), c.n_forward_grads); - assert!(c.started, "autodiff section is not started"); - assert!( - c.is_forward_mode, - "propagate_gradient() is called in backward mode" - ); - }); - __current_scope(|b| { - let mut nodes = vec![v.node()]; - nodes.extend(grads.iter().map(|g| g.node())); - b.call(Func::PropagateGrad, &nodes, Type::void()); - }); -} - -pub fn output_gradients(v: T) -> Vec { - let n = AD_CONTEXT.with(|c| { - let c = c.borrow(); - assert!(c.started, "autodiff section is not started"); - assert!( - c.is_forward_mode, - "output_gradients() is called in backward mode" - ); - c.n_forward_grads - }); - __current_scope(|b| { - let mut grads = vec![]; - for i in 0..n { - let idx = b.const_(Const::Int32(i as i32)); - grads.push(T::from_node(b.call( - Func::OutputGrad, - &[v.node(), idx], - v.node().type_().clone(), - ))); - } - grads - }) -} - -pub fn autodiff(body: impl Fn()) { - AD_CONTEXT.with(|c| { - let mut c = c.borrow_mut(); - assert!(!c.started, "autodiff section is already started"); - *c = AdContext::new_rev(); - c.started = true; - }); - RECORDER.with(|r| { - let mut r = r.borrow_mut(); - let pools = r.pools.clone().unwrap(); - let s = &mut r.scopes; - s.push(IrBuilder::new(pools)); - }); - body(); - AD_CONTEXT.with(|c| { - let mut c = c.borrow_mut(); - assert!(c.started, "autodiff section is not started"); - assert!(c.backward_called, "backward is not called"); - c.reset(); - }); - let body = __pop_scope(); - __current_scope(|b| { - b.ad_scope(body); - }); -} - -pub fn is_cpu_backend() -> bool { - RECORDER.with(|r| { - let r = r.borrow(); - if r.device.is_none() { - return false; - } - r.device - .as_ref() - .unwrap() - .upgrade() - .unwrap() - .inner - .query("device_name") - .map(|s| s == "cpu") - .unwrap_or(false) - }) -} - -pub fn __env_need_backtrace() -> bool { - match std::env::var("LUISA_BACKTRACE") { - Ok(s) => s == "1" || s == "ON", - Err(_) => false, - } -} - -pub fn __unreachable(file: &str, line: u32, col: u32) { - let path = std::path::Path::new(file); - let pretty_filename: String; - if path.exists() { - pretty_filename = std::fs::canonicalize(path) - .unwrap() - .to_str() - .unwrap() - .to_string(); - } else { - pretty_filename = file.to_string(); - } - let msg = if is_cpu_backend() && __env_need_backtrace() { - let backtrace = get_backtrace(); - format!( - "unreachable code at {}:{}:{} \nbacktrace: {}", - pretty_filename, line, col, backtrace - ) - } else { - format!( - "unreachable code at {}:{}:{} \n", - pretty_filename, line, col - ) - }; - __current_scope(|b| { - b.call( - Func::Unreachable(CBoxedSlice::new( - CString::new(msg).unwrap().into_bytes_with_nul(), - )), - &[], - Type::void(), - ); - }); -} - -#[inline] -pub fn __assert(cond: impl Into>, msg: &str, file: &str, line: u32, col: u32) { - let cond = cond.into(); - let path = std::path::Path::new(file); - let pretty_filename: String; - if path.exists() { - pretty_filename = std::fs::canonicalize(path) - .unwrap() - .to_str() - .unwrap() - .to_string(); - } else { - pretty_filename = file.to_string(); - } - let msg = if is_cpu_backend() && __env_need_backtrace() { - let backtrace = get_backtrace(); - format!( - "assertion failed: {} at {}:{}:{} \nbacktrace: {}", - msg, pretty_filename, line, col, backtrace - ) - } else { - format!( - "assertion failed: {} at {}:{}:{} \n", - msg, pretty_filename, line, col - ) - }; - __current_scope(|b| { - b.call( - Func::Assert(CBoxedSlice::new( - CString::new(msg).unwrap().into_bytes_with_nul(), - )), - &[cond.node()], - Type::void(), - ); - }); -} - -pub(crate) fn need_runtime_check() -> bool { - cfg!(debug_assertions) - || match env::var("LUISA_DEBUG") { - Ok(s) => s == "full" || s == "1", - Err(_) => false, - } - || __env_need_backtrace() -} diff --git a/luisa_compute/src/lang/traits.rs b/luisa_compute/src/lang/ops.rs similarity index 56% rename from luisa_compute/src/lang/traits.rs rename to luisa_compute/src/lang/ops.rs index 8c02e94..4920bf7 100644 --- a/luisa_compute/src/lang/traits.rs +++ b/luisa_compute/src/lang/ops.rs @@ -1,12 +1,8 @@ -use crate::prelude::*; -use crate::*; -use luisa_compute_ir::ir::new_user_node; -use luisa_compute_ir::CArc; -use luisa_compute_ir::{ir::Func, ir::Type, TypeOf}; -use std::cell::{Cell, RefCell}; +use crate::internal_prelude::*; use std::ops::*; -use super::Expr; +pub mod impls; + pub trait VarTrait: Copy + Clone + 'static + FromNode { type Value: Value; type Short: VarTrait; @@ -23,53 +19,7 @@ pub trait VarTrait: Copy + Clone + 'static + FromNode { ::type_() } } -macro_rules! impl_var_trait { - ($t:ty) => { - impl VarTrait for PrimExpr<$t> { - type Value = $t; - type Short = Expr; - type Ushort = Expr; - type Int = Expr; - type Uint = Expr; - type Long = Expr; - type Ulong = Expr; - type Half = Expr; - type Float = Expr; - type Double = Expr; - type Bool = Expr; - } - impl ScalarVarTrait for PrimExpr<$t> {} - impl ScalarOrVector for PrimExpr<$t> { - type Element = PrimExpr<$t>; - type ElementHost = $t; - } - impl BuiltinVarTrait for PrimExpr<$t> {} - }; -} -impl_var_trait!(f16); -impl_var_trait!(f32); -impl_var_trait!(f64); -impl_var_trait!(i16); -impl_var_trait!(u16); -impl_var_trait!(i32); -impl_var_trait!(u32); -impl_var_trait!(i64); -impl_var_trait!(u64); -impl_var_trait!(bool); -impl FromNode for PrimExpr { - fn from_node(node: NodeRef) -> Self { - Self { - node, - _phantom: std::marker::PhantomData, - } - } -} -impl ToNode for PrimExpr { - fn node(&self) -> NodeRef { - self.node - } -} fn _cast(expr: T) -> U { let node = expr.node(); __current_scope(|s| { @@ -529,345 +479,7 @@ pub trait FloatVarTrait: (self.sin(), self.cos()) } } -macro_rules! impl_binop { - ($t:ty, $proxy:ty, $tr_assign:ident, $method_assign:ident, $tr:ident, $method:ident) => { - impl $tr_assign> for $proxy { - fn $method_assign(&mut self, rhs: Expr<$t>) { - *self = self.clone().$method(rhs); - } - } - impl $tr_assign<$t> for $proxy { - fn $method_assign(&mut self, rhs: $t) { - *self = self.clone().$method(rhs); - } - } - impl $tr> for $proxy { - type Output = Expr<$t>; - fn $method(self, rhs: Expr<$t>) -> Self::Output { - __current_scope(|s| { - let lhs = ToNode::node(&self); - let rhs = ToNode::node(&rhs); - let ret = s.call(Func::$tr, &[lhs, rhs], Self::Output::type_()); - Expr::<$t>::from_node(ret) - }) - } - } - - impl $tr<$t> for $proxy { - type Output = Expr<$t>; - fn $method(self, rhs: $t) -> Self::Output { - $tr::$method(self, const_(rhs)) - } - } - impl $tr<$proxy> for $t { - type Output = Expr<$t>; - fn $method(self, rhs: $proxy) -> Self::Output { - $tr::$method(const_(self), rhs) - } - } - }; -} -macro_rules! impl_common_binop { - ($t:ty,$proxy:ty) => { - impl_binop!($t, $proxy, AddAssign, add_assign, Add, add); - impl_binop!($t, $proxy, SubAssign, sub_assign, Sub, sub); - impl_binop!($t, $proxy, MulAssign, mul_assign, Mul, mul); - impl_binop!($t, $proxy, DivAssign, div_assign, Div, div); - impl_binop!($t, $proxy, RemAssign, rem_assign, Rem, rem); - }; -} -macro_rules! impl_int_binop { - ($t:ty,$proxy:ty) => { - impl_binop!($t, $proxy, ShlAssign, shl_assign, Shl, shl); - impl_binop!($t, $proxy, ShrAssign, shr_assign, Shr, shr); - impl_binop!($t, $proxy, BitAndAssign, bitand_assign, BitAnd, bitand); - impl_binop!($t, $proxy, BitOrAssign, bitor_assign, BitOr, bitor); - impl_binop!($t, $proxy, BitXorAssign, bitxor_assign, BitXor, bitxor); - }; -} - -macro_rules! impl_not { - ($t:ty,$proxy:ty) => { - impl Not for $proxy { - type Output = Expr<$t>; - fn not(self) -> Self::Output { - __current_scope(|s| { - let ret = s.call(Func::BitNot, &[ToNode::node(&self)], Self::Output::type_()); - Expr::<$t>::from_node(ret) - }) - } - } - }; -} -macro_rules! impl_neg { - ($t:ty,$proxy:ty) => { - impl Neg for $proxy { - type Output = Expr<$t>; - fn neg(self) -> Self::Output { - __current_scope(|s| { - let ret = s.call(Func::Neg, &[ToNode::node(&self)], Self::Output::type_()); - Expr::<$t>::from_node(ret) - }) - } - } - }; -} -macro_rules! impl_fneg { - ($t:ty, $proxy:ty) => { - impl Neg for $proxy { - type Output = Expr<$t>; - fn neg(self) -> Self::Output { - __current_scope(|s| { - let ret = s.call(Func::Neg, &[ToNode::node(&self)], Self::Output::type_()); - Expr::<$t>::from_node(ret) - }) - } - } - }; -} -impl Not for PrimExpr { - type Output = Expr; - fn not(self) -> Self::Output { - __current_scope(|s| { - let ret = s.call(Func::BitNot, &[ToNode::node(&self)], Self::Output::type_()); - FromNode::from_node(ret) - }) - } -} -impl_common_binop!(f16, PrimExpr); -impl_common_binop!(f32, PrimExpr); -impl_common_binop!(f64, PrimExpr); -impl_common_binop!(i16, PrimExpr); -impl_common_binop!(i32, PrimExpr); -impl_common_binop!(i64, PrimExpr); -impl_common_binop!(u16, PrimExpr); -impl_common_binop!(u32, PrimExpr); -impl_common_binop!(u64, PrimExpr); - -impl_binop!( - bool, - PrimExpr, - BitAndAssign, - bitand_assign, - BitAnd, - bitand -); -impl_binop!( - bool, - PrimExpr, - BitOrAssign, - bitor_assign, - BitOr, - bitor -); -impl_binop!( - bool, - PrimExpr, - BitXorAssign, - bitxor_assign, - BitXor, - bitxor -); -impl_int_binop!(i16, PrimExpr); -impl_int_binop!(i32, PrimExpr); -impl_int_binop!(i64, PrimExpr); -impl_int_binop!(u16, PrimExpr); -impl_int_binop!(u32, PrimExpr); -impl_int_binop!(u64, PrimExpr); - -impl_not!(i16, PrimExpr); -impl_not!(i32, PrimExpr); -impl_not!(i64, PrimExpr); -impl_not!(u16, PrimExpr); -impl_not!(u32, PrimExpr); -impl_not!(u64, PrimExpr); - -impl_neg!(i16, PrimExpr); -impl_neg!(i32, PrimExpr); -impl_neg!(i64, PrimExpr); -impl_neg!(u16, PrimExpr); -impl_neg!(u32, PrimExpr); -impl_neg!(u64, PrimExpr); - -impl_fneg!(f16, PrimExpr); -impl_fneg!(f32, PrimExpr); -impl_fneg!(f64, PrimExpr); - -impl VarCmpEq for PrimExpr {} -impl VarCmpEq for PrimExpr {} -impl VarCmpEq for PrimExpr {} -impl VarCmpEq for PrimExpr {} -impl VarCmpEq for PrimExpr {} -impl VarCmpEq for PrimExpr {} -impl VarCmpEq for PrimExpr {} -impl VarCmpEq for PrimExpr {} -impl VarCmpEq for PrimExpr {} - -impl VarCmpEq for PrimExpr {} - -impl VarCmp for PrimExpr {} -impl VarCmp for PrimExpr {} -impl VarCmp for PrimExpr {} -impl VarCmp for PrimExpr {} -impl VarCmp for PrimExpr {} -impl VarCmp for PrimExpr {} -impl VarCmp for PrimExpr {} -impl VarCmp for PrimExpr {} -impl VarCmp for PrimExpr {} - -impl CommonVarOp for PrimExpr {} -impl CommonVarOp for PrimExpr {} -impl CommonVarOp for PrimExpr {} -impl CommonVarOp for PrimExpr {} -impl CommonVarOp for PrimExpr {} -impl CommonVarOp for PrimExpr {} -impl CommonVarOp for PrimExpr {} -impl CommonVarOp for PrimExpr {} -impl CommonVarOp for PrimExpr {} - -impl CommonVarOp for PrimExpr {} - -impl From for Float { - fn from(x: f64) -> Self { - (x as f32).into() - } -} -impl From for Double { - fn from(x: f32) -> Self { - (x as f64).into() - } -} -impl From for Half { - fn from(x: f64) -> Self { - f16::from_f64(x).into() - } -} -impl From for Half { - fn from(x: f32) -> Self { - f16::from_f32(x).into() - } -} - -impl FloatVarTrait for PrimExpr {} -impl FloatVarTrait for PrimExpr {} -impl FloatVarTrait for PrimExpr {} - -impl IntVarTrait for PrimExpr {} -impl IntVarTrait for PrimExpr {} -impl IntVarTrait for PrimExpr {} -impl IntVarTrait for PrimExpr {} -impl IntVarTrait for PrimExpr {} -impl IntVarTrait for PrimExpr {} - -macro_rules! impl_from { - ($from:ty, $to:ty) => { - impl From<$from> for PrimExpr<$to> { - fn from(x: $from) -> Self { - const_::<$to>(x.try_into().unwrap()) - } - } - }; -} - -impl_from!(i16, u16); -impl_from!(i16, i32); -impl_from!(i16, u32); -impl_from!(i16, i64); -impl_from!(i16, u64); - -impl_from!(u16, i16); -impl_from!(u16, i32); -impl_from!(u16, u32); -impl_from!(u16, i64); -impl_from!(u16, u64); - -impl_from!(i32, u16); -impl_from!(i32, i16); -impl_from!(i32, u32); -impl_from!(i32, i64); -impl_from!(i32, u64); - -impl_from!(i64, u16); -impl_from!(i64, i16); -impl_from!(i64, u64); -impl_from!(i64, i32); -impl_from!(i64, u32); - -impl_from!(u32, u16); -impl_from!(u32, i16); -impl_from!(u32, i32); -impl_from!(u32, i64); -impl_from!(u32, u64); - -impl_from!(u64, u16); -impl_from!(u64, i16); -impl_from!(u64, i64); -impl_from!(u64, i32); -impl_from!(u64, u32); - -impl Aggregate for Vec { - fn to_nodes(&self, nodes: &mut Vec) { - let len_node = new_user_node(__module_pools(), nodes.len()); - nodes.push(len_node); - for item in self { - item.to_nodes(nodes); - } - } - - fn from_nodes>(iter: &mut I) -> Self { - let len_node = iter.next().unwrap(); - let len = len_node.unwrap_user_data::(); - let mut ret = Vec::with_capacity(*len); - for _ in 0..*len { - ret.push(T::from_nodes(iter)); - } - ret - } -} - -impl Aggregate for RefCell { - fn to_nodes(&self, nodes: &mut Vec) { - self.borrow().to_nodes(nodes); - } - - fn from_nodes>(iter: &mut I) -> Self { - RefCell::new(T::from_nodes(iter)) - } -} -impl Aggregate for Cell { - fn to_nodes(&self, nodes: &mut Vec) { - self.get().to_nodes(nodes); - } - - fn from_nodes>(iter: &mut I) -> Self { - Cell::new(T::from_nodes(iter)) - } -} -impl Aggregate for Option { - fn to_nodes(&self, nodes: &mut Vec) { - match self { - Some(x) => { - let node = new_user_node(__module_pools(), 1); - nodes.push(node); - x.to_nodes(nodes); - } - None => { - let node = new_user_node(__module_pools(), 0); - nodes.push(node); - } - } - } - fn from_nodes>(iter: &mut I) -> Self { - let node = iter.next().unwrap(); - let tag = node.unwrap_user_data::(); - match *tag { - 0 => None, - 1 => Some(T::from_nodes(iter)), - _ => unreachable!(), - } - } -} pub trait ScalarVarTrait: ToNode + FromNode {} pub trait VectorVarTrait: ToNode + FromNode {} pub trait MatrixVarTrait: ToNode + FromNode {} @@ -876,6 +488,3 @@ pub trait ScalarOrVector: ToNode + FromNode { type ElementHost: Value; } pub trait BuiltinVarTrait: ToNode + FromNode {} -pub trait Int32 {} -impl Int32 for i32 {} -impl Int32 for u32 {} \ No newline at end of file diff --git a/luisa_compute/src/lang/ops/impls.rs b/luisa_compute/src/lang/ops/impls.rs new file mode 100644 index 0000000..db3af67 --- /dev/null +++ b/luisa_compute/src/lang/ops/impls.rs @@ -0,0 +1,338 @@ +use super::*; +use crate::lang::types::core::*; +use crate::lang::types::VarDerefProxy; + +macro_rules! impl_var_trait { + ($t:ty) => { + impl VarTrait for prim::Expr<$t> { + type Value = $t; + type Short = prim::Expr; + type Ushort = prim::Expr; + type Int = prim::Expr; + type Uint = prim::Expr; + type Long = prim::Expr; + type Ulong = prim::Expr; + type Half = prim::Expr; + type Float = prim::Expr; + type Double = prim::Expr; + type Bool = prim::Expr; + } + impl ScalarVarTrait for prim::Expr<$t> {} + impl ScalarOrVector for prim::Expr<$t> { + type Element = prim::Expr<$t>; + type ElementHost = $t; + } + impl BuiltinVarTrait for prim::Expr<$t> {} + }; +} +impl_var_trait!(f16); +impl_var_trait!(f32); +impl_var_trait!(f64); +impl_var_trait!(i16); +impl_var_trait!(u16); +impl_var_trait!(i32); +impl_var_trait!(u32); +impl_var_trait!(i64); +impl_var_trait!(u64); +impl_var_trait!(bool); + +impl VarCmpEq for prim::Expr {} +impl VarCmpEq for prim::Expr {} +impl VarCmpEq for prim::Expr {} +impl VarCmpEq for prim::Expr {} +impl VarCmpEq for prim::Expr {} +impl VarCmpEq for prim::Expr {} +impl VarCmpEq for prim::Expr {} +impl VarCmpEq for prim::Expr {} +impl VarCmpEq for prim::Expr {} + +impl VarCmpEq for prim::Expr {} + +impl VarCmp for prim::Expr {} +impl VarCmp for prim::Expr {} +impl VarCmp for prim::Expr {} +impl VarCmp for prim::Expr {} +impl VarCmp for prim::Expr {} +impl VarCmp for prim::Expr {} +impl VarCmp for prim::Expr {} +impl VarCmp for prim::Expr {} +impl VarCmp for prim::Expr {} + +impl CommonVarOp for prim::Expr {} +impl CommonVarOp for prim::Expr {} +impl CommonVarOp for prim::Expr {} +impl CommonVarOp for prim::Expr {} +impl CommonVarOp for prim::Expr {} +impl CommonVarOp for prim::Expr {} +impl CommonVarOp for prim::Expr {} +impl CommonVarOp for prim::Expr {} +impl CommonVarOp for prim::Expr {} + +impl CommonVarOp for prim::Expr {} + +impl FloatVarTrait for prim::Expr {} +impl FloatVarTrait for prim::Expr {} +impl FloatVarTrait for prim::Expr {} + +impl IntVarTrait for prim::Expr {} +impl IntVarTrait for prim::Expr {} +impl IntVarTrait for prim::Expr {} +impl IntVarTrait for prim::Expr {} +impl IntVarTrait for prim::Expr {} +impl IntVarTrait for prim::Expr {} + +macro_rules! impl_from { + ($from:ty, $to:ty) => { + impl From<$from> for prim::Expr<$to> { + fn from(x: $from) -> Self { + let y: $to = (x.try_into().unwrap()); + y.expr() + } + } + }; +} + +impl_from!(i16, u16); +impl_from!(i16, i32); +impl_from!(i16, u32); +impl_from!(i16, i64); +impl_from!(i16, u64); + +impl_from!(u16, i16); +impl_from!(u16, i32); +impl_from!(u16, u32); +impl_from!(u16, i64); +impl_from!(u16, u64); + +impl_from!(i32, u16); +impl_from!(i32, i16); +impl_from!(i32, u32); +impl_from!(i32, i64); +impl_from!(i32, u64); + +impl_from!(i64, u16); +impl_from!(i64, i16); +impl_from!(i64, u64); +impl_from!(i64, i32); +impl_from!(i64, u32); + +impl_from!(u32, u16); +impl_from!(u32, i16); +impl_from!(u32, i32); +impl_from!(u32, i64); +impl_from!(u32, u64); + +impl_from!(u64, u16); +impl_from!(u64, i16); +impl_from!(u64, i64); +impl_from!(u64, i32); +impl_from!(u64, u32); + +impl From for prim::Expr { + fn from(x: f64) -> Self { + (x as f32).into() + } +} +impl From for prim::Expr { + fn from(x: f32) -> Self { + (x as f64).into() + } +} +impl From for prim::Expr { + fn from(x: f64) -> Self { + f16::from_f64(x).into() + } +} +impl From for prim::Expr { + fn from(x: f32) -> Self { + f16::from_f32(x).into() + } +} + +macro_rules! impl_binop { + ($t:ty, $proxy:ty, $tr_assign:ident, $method_assign:ident, $tr:ident, $method:ident) => { + impl $tr_assign> for $proxy { + fn $method_assign(&mut self, rhs: prim::Expr<$t>) { + *self = self.clone().$method(rhs); + } + } + impl $tr_assign<$t> for $proxy { + fn $method_assign(&mut self, rhs: $t) { + *self = self.clone().$method(rhs); + } + } + impl $tr> for $proxy { + type Output = prim::Expr<$t>; + fn $method(self, rhs: prim::Expr<$t>) -> Self::Output { + __current_scope(|s| { + let lhs = ToNode::node(&self); + let rhs = ToNode::node(&rhs); + let ret = s.call(Func::$tr, &[lhs, rhs], Self::Output::type_()); + Expr::<$t>::from_node(ret) + }) + } + } + + impl $tr<$t> for $proxy { + type Output = prim::Expr<$t>; + fn $method(self, rhs: $t) -> Self::Output { + $tr::$method(self, rhs.expr()) + } + } + impl $tr<$proxy> for $t { + type Output = prim::Expr<$t>; + fn $method(self, rhs: $proxy) -> Self::Output { + $tr::$method(self.expr(), rhs) + } + } + }; +} +macro_rules! impl_common_binop { + ($t:ty,$proxy:ty) => { + impl_binop!($t, $proxy, AddAssign, add_assign, Add, add); + impl_binop!($t, $proxy, SubAssign, sub_assign, Sub, sub); + impl_binop!($t, $proxy, MulAssign, mul_assign, Mul, mul); + impl_binop!($t, $proxy, DivAssign, div_assign, Div, div); + impl_binop!($t, $proxy, RemAssign, rem_assign, Rem, rem); + }; +} +macro_rules! impl_int_binop { + ($t:ty,$proxy:ty) => { + impl_binop!($t, $proxy, ShlAssign, shl_assign, Shl, shl); + impl_binop!($t, $proxy, ShrAssign, shr_assign, Shr, shr); + impl_binop!($t, $proxy, BitAndAssign, bitand_assign, BitAnd, bitand); + impl_binop!($t, $proxy, BitOrAssign, bitor_assign, BitOr, bitor); + impl_binop!($t, $proxy, BitXorAssign, bitxor_assign, BitXor, bitxor); + }; +} + +macro_rules! impl_not { + ($t:ty,$proxy:ty) => { + impl Not for $proxy { + type Output = prim::Expr<$t>; + fn not(self) -> Self::Output { + __current_scope(|s| { + let ret = s.call(Func::BitNot, &[ToNode::node(&self)], Self::Output::type_()); + Expr::<$t>::from_node(ret) + }) + } + } + }; +} +macro_rules! impl_neg { + ($t:ty,$proxy:ty) => { + impl Neg for $proxy { + type Output = prim::Expr<$t>; + fn neg(self) -> Self::Output { + __current_scope(|s| { + let ret = s.call(Func::Neg, &[ToNode::node(&self)], Self::Output::type_()); + Expr::<$t>::from_node(ret) + }) + } + } + }; +} +macro_rules! impl_fneg { + ($t:ty, $proxy:ty) => { + impl Neg for $proxy { + type Output = prim::Expr<$t>; + fn neg(self) -> Self::Output { + __current_scope(|s| { + let ret = s.call(Func::Neg, &[ToNode::node(&self)], Self::Output::type_()); + Expr::<$t>::from_node(ret) + }) + } + } + }; +} +impl Not for prim::Expr { + type Output = prim::Expr; + fn not(self) -> Self::Output { + __current_scope(|s| { + let ret = s.call(Func::BitNot, &[ToNode::node(&self)], Self::Output::type_()); + FromNode::from_node(ret) + }) + } +} +impl_common_binop!(f16, prim::Expr); +impl_common_binop!(f32, prim::Expr); +impl_common_binop!(f64, prim::Expr); +impl_common_binop!(i16, prim::Expr); +impl_common_binop!(i32, prim::Expr); +impl_common_binop!(i64, prim::Expr); +impl_common_binop!(u16, prim::Expr); +impl_common_binop!(u32, prim::Expr); +impl_common_binop!(u64, prim::Expr); + +impl_binop!( + bool, + prim::Expr, + BitAndAssign, + bitand_assign, + BitAnd, + bitand +); +impl_binop!( + bool, + prim::Expr, + BitOrAssign, + bitor_assign, + BitOr, + bitor +); +impl_binop!( + bool, + prim::Expr, + BitXorAssign, + bitxor_assign, + BitXor, + bitxor +); +impl_int_binop!(i16, prim::Expr); +impl_int_binop!(i32, prim::Expr); +impl_int_binop!(i64, prim::Expr); +impl_int_binop!(u16, prim::Expr); +impl_int_binop!(u32, prim::Expr); +impl_int_binop!(u64, prim::Expr); + +impl_not!(i16, prim::Expr); +impl_not!(i32, prim::Expr); +impl_not!(i64, prim::Expr); +impl_not!(u16, prim::Expr); +impl_not!(u32, prim::Expr); +impl_not!(u64, prim::Expr); + +impl_neg!(i16, prim::Expr); +impl_neg!(i32, prim::Expr); +impl_neg!(i64, prim::Expr); +impl_neg!(u16, prim::Expr); +impl_neg!(u32, prim::Expr); +impl_neg!(u64, prim::Expr); + +impl_fneg!(f16, prim::Expr); +impl_fneg!(f32, prim::Expr); +impl_fneg!(f64, prim::Expr); + +macro_rules! impl_assign_ops { + ($ass:ident, $ass_m:ident, $o:ident, $o_m:ident) => { + impl std::ops::$ass for VarDerefProxy + where + P: VarProxy, + Expr: std::ops::$o>, + { + fn $ass_m(&mut self, rhs: Rhs) { + *self.deref_mut() = std::ops::$o::$o_m(**self, rhs); + } + } + }; +} +impl_assign_ops!(AddAssign, add_assign, Add, add); +impl_assign_ops!(SubAssign, sub_assign, Sub, sub); +impl_assign_ops!(MulAssign, mul_assign, Mul, mul); +impl_assign_ops!(DivAssign, div_assign, Div, div); +impl_assign_ops!(RemAssign, rem_assign, Rem, rem); +impl_assign_ops!(BitAndAssign, bitand_assign, BitAnd, bitand); +impl_assign_ops!(BitOrAssign, bitor_assign, BitOr, bitor); +impl_assign_ops!(BitXorAssign, bitxor_assign, BitXor, bitxor); +impl_assign_ops!(ShlAssign, shl_assign, Shl, shl); +impl_assign_ops!(ShrAssign, shr_assign, Shr, shr); diff --git a/luisa_compute/src/lang/poly.rs b/luisa_compute/src/lang/poly.rs index a45611f..ef34531 100644 --- a/luisa_compute/src/lang/poly.rs +++ b/luisa_compute/src/lang/poly.rs @@ -1,30 +1,28 @@ -use std::{ - any::{Any, TypeId}, - collections::HashMap, - fmt::Debug, - hash::Hash, -}; +use std::any::{Any, TypeId}; +use std::collections::HashMap; +use std::fmt::Debug; +use std::hash::Hash; -use crate::*; -use crate::{resource::Buffer, Device}; -use luisa_compute_derive::__Value; +use crate::internal_prelude::*; -use super::{switch, traits::CommonVarOp, Aggregate, Uint, Value}; +use crate::lang::control_flow::switch; + +pub use crate::impl_polymorphic; pub struct PolyArray { tag: i32, key: K, - get: Box Box>, - _marker: std::marker::PhantomData, + get: Box) -> Box>, + _marker: PhantomData, } impl PolyArray { - pub fn new(tag: i32, key: K, get: Box Box>) -> Self { + pub fn new(tag: i32, key: K, get: Box) -> Box>) -> Self { Self { tag, get, key, - _marker: std::marker::PhantomData, + _marker: PhantomData, } } pub fn tag(&self) -> i32 { @@ -49,14 +47,14 @@ macro_rules! impl_new_poly_array { #[macro_export] macro_rules! impl_polymorphic { ($trait_:ident, $ty:ty) => { - impl PolymorphicImpl for $ty { + impl luisa_compute::lang::poly::PolymorphicImpl for $ty { fn new_poly_array( - buffer: &luisa_compute::Buffer, + buffer: &luisa_compute::resource::Buffer, tag: i32, key: K, - ) -> luisa_compute::PolyArray { + ) -> luisa_compute::lang::poly::PolyArray { let buffer = unsafe { buffer.shallow_clone() }; - luisa_compute::PolyArray::new( + luisa_compute::lang::poly::PolyArray::new( tag, key, Box::new(move |_, index| Box::new(buffer.var().read(index))), @@ -80,7 +78,7 @@ impl PolyVec { } } -#[derive(Clone, Copy, Debug, __Value, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Clone, Copy, Debug, Value, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(C)] pub struct TagIndex { pub tag: u32, @@ -209,7 +207,7 @@ impl PolymorphicBui Thus a tag can be retrieved by either (Key, TypeId) or Key alone if the key is unique for all types. */ pub struct Polymorphic { - _marker: std::marker::PhantomData, + _marker: PhantomData, arrays: Vec>, key_typeid_to_tag: HashMap<(DevirtualizationKey, TypeId), u32>, key_to_tag: HashMap>, @@ -279,7 +277,7 @@ impl<'a, K: Eq + Clone + Hash, T: ?Sized + 'static> PolymorphicRef<'a, K, T> { impl Polymorphic { pub fn new() -> Self { Self { - _marker: std::marker::PhantomData, + _marker: PhantomData, arrays: vec![], key_to_tag: HashMap::new(), key_typeid_to_tag: HashMap::new(), diff --git a/luisa_compute/src/lang/types.rs b/luisa_compute/src/lang/types.rs new file mode 100644 index 0000000..dc8f3f1 --- /dev/null +++ b/luisa_compute/src/lang/types.rs @@ -0,0 +1,168 @@ +use std::any::Any; +use std::cell::Cell; +use std::ops::{Deref, DerefMut}; + +use crate::internal_prelude::*; + +pub mod array; +pub mod core; +pub mod dynamic; +pub mod shared; +pub mod vector; + +pub type Expr = ::Expr; +pub type Var = ::Var; + +pub trait Value: Copy + ir::TypeOf + 'static { + type Expr: ExprProxy; + type Var: VarProxy; + fn fields() -> Vec; + fn expr(self) -> Self::Expr { + const_(self) + } + fn var(self) -> Self::Var { + local::(self.expr()) + } +} + +pub trait ExprProxy: Copy + Aggregate + FromNode { + type Value: Value; + + fn var(self) -> Var { + def(self) + } + + fn zeroed() -> Self { + zeroed::() + } +} + +pub trait VarProxy: Copy + Aggregate + FromNode { + type Value: Value; + fn store>>(&self, value: U) { + let value = value.into(); + super::_store(self, &value); + } + fn load(&self) -> Expr { + __current_scope(|b| { + let nodes = self.to_vec_nodes(); + let mut ret = vec![]; + for node in nodes { + ret.push(b.call(Func::Load, &[node], node.type_().clone())); + } + Expr::::from_nodes(&mut ret.into_iter()) + }) + } + fn get_mut(&self) -> VarDerefProxy { + VarDerefProxy { + var: *self, + dirty: Cell::new(false), + assigned: self.load(), + _phantom: PhantomData, + } + } + fn _deref<'a>(&'a self) -> &'a Expr { + RECORDER.with(|r| { + let v: Expr = self.load(); + let r = r.borrow(); + let v: &Expr = r.arena.alloc(v); + unsafe { + let v: &'a Expr = std::mem::transmute(v); + v + } + }) + } + fn zeroed() -> Self { + local_zeroed::() + } +} + +pub struct VarDerefProxy +where + P: VarProxy, +{ + pub(crate) var: P, + pub(crate) dirty: Cell, + pub(crate) assigned: Expr, + pub(crate) _phantom: PhantomData, +} + +impl Deref for VarDerefProxy +where + P: VarProxy, +{ + type Target = Expr; + fn deref(&self) -> &Self::Target { + &self.assigned + } +} + +impl DerefMut for VarDerefProxy +where + P: VarProxy, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + self.dirty.set(true); + &mut self.assigned + } +} + +impl Drop for VarDerefProxy +where + P: VarProxy, +{ + fn drop(&mut self) { + if self.dirty.get() { + self.var.store(self.assigned) + } + } +} + +fn def, T: Value>(init: E) -> Var { + Var::::from_node(__current_scope(|b| b.local(init.node()))) +} +fn local(init: Expr) -> Var { + Var::::from_node(__current_scope(|b| b.local(init.node()))) +} + +fn local_zeroed() -> Var { + Var::::from_node(__current_scope(|b| { + b.local_zero_init(::type_()) + })) +} + +fn zeroed() -> T::Expr { + FromNode::from_node(__current_scope(|b| b.zero_initializer(T::type_()))) +} + +fn const_(value: T) -> T::Expr { + let node = __current_scope(|s| -> NodeRef { + let any = &value as &dyn Any; + if let Some(value) = any.downcast_ref::() { + s.const_(Const::Bool(*value)) + } else if let Some(value) = any.downcast_ref::() { + s.const_(Const::Int32(*value)) + } else if let Some(value) = any.downcast_ref::() { + s.const_(Const::Uint32(*value)) + } else if let Some(value) = any.downcast_ref::() { + s.const_(Const::Int64(*value)) + } else if let Some(value) = any.downcast_ref::() { + s.const_(Const::Uint64(*value)) + } else if let Some(value) = any.downcast_ref::() { + s.const_(Const::Float32(*value)) + } else if let Some(value) = any.downcast_ref::() { + s.const_(Const::Float64(*value)) + } else { + let mut buf = vec![0u8; std::mem::size_of::()]; + unsafe { + std::ptr::copy_nonoverlapping( + &value as *const T as *const u8, + buf.as_mut_ptr(), + buf.len(), + ); + } + s.const_(Const::Generic(CBoxedSlice::new(buf), T::type_())) + } + }); + FromNode::from_node(node) +} diff --git a/luisa_compute/src/lang/types/array.rs b/luisa_compute/src/lang/types/array.rs new file mode 100644 index 0000000..e367966 --- /dev/null +++ b/luisa_compute/src/lang/types/array.rs @@ -0,0 +1,294 @@ +use super::*; +use crate::lang::index::IntoIndex; +use ir::ArrayType; + +#[derive(Clone, Copy, Debug)] +pub struct ArrayExpr { + marker: PhantomData, + node: NodeRef, +} + +#[derive(Clone, Copy, Debug)] +pub struct ArrayVar { + marker: PhantomData, + node: NodeRef, +} + +impl FromNode for ArrayExpr { + fn from_node(node: NodeRef) -> Self { + Self { + marker: PhantomData, + node, + } + } +} + +impl ToNode for ArrayExpr { + fn node(&self) -> NodeRef { + self.node + } +} + +impl Aggregate for ArrayExpr { + fn to_nodes(&self, nodes: &mut Vec) { + nodes.push(self.node); + } + fn from_nodes>(iter: &mut I) -> Self { + Self::from_node(iter.next().unwrap()) + } +} + +impl FromNode for ArrayVar { + fn from_node(node: NodeRef) -> Self { + Self { + marker: PhantomData, + node, + } + } +} + +impl ToNode for ArrayVar { + fn node(&self) -> NodeRef { + self.node + } +} + +impl Aggregate for ArrayVar { + fn to_nodes(&self, nodes: &mut Vec) { + nodes.push(self.node); + } + fn from_nodes>(iter: &mut I) -> Self { + Self::from_node(iter.next().unwrap()) + } +} + +impl ExprProxy for ArrayExpr { + type Value = [T; N]; +} + +impl VarProxy for ArrayVar { + type Value = [T; N]; +} + +impl ArrayVar { + pub fn len(&self) -> Expr { + (N as u32).expr() + } +} + +impl ArrayExpr { + pub fn zero() -> Self { + let node = __current_scope(|b| b.call(Func::ZeroInitializer, &[], <[T; N]>::type_())); + Self::from_node(node) + } + pub fn len(&self) -> Expr { + (N as u32).expr() + } +} + +impl IndexRead for ArrayExpr { + type Element = T; + fn read(&self, i: I) -> Expr { + let i = i.to_u64(); + + lc_assert!(i.cmplt((N as u64).expr())); + + Expr::::from_node(__current_scope(|b| { + b.call(Func::ExtractElement, &[self.node, i.node()], T::type_()) + })) + } +} + +impl IndexRead for ArrayVar { + type Element = T; + fn read(&self, i: I) -> Expr { + let i = i.to_u64(); + if need_runtime_check() { + lc_assert!(i.cmplt((N as u64).expr())); + } + + Expr::::from_node(__current_scope(|b| { + let gep = b.call(Func::GetElementPtr, &[self.node, i.node()], T::type_()); + b.call(Func::Load, &[gep], T::type_()) + })) + } +} + +impl IndexWrite for ArrayVar { + fn write>>(&self, i: I, value: V) { + let i = i.to_u64(); + let value = value.into(); + + if need_runtime_check() { + lc_assert!(i.cmplt((N as u64).expr())); + } + + __current_scope(|b| { + let gep = b.call(Func::GetElementPtr, &[self.node, i.node()], T::type_()); + b.update(gep, value.node()); + }); + } +} + +impl Value for [T; N] { + type Expr = ArrayExpr; + type Var = ArrayVar; + fn fields() -> Vec { + todo!("why this method exists?") + } +} + +#[derive(Clone, Copy, Debug)] +pub struct VLArrayExpr { + marker: PhantomData, + pub(super) node: NodeRef, +} + +impl FromNode for VLArrayExpr { + fn from_node(node: NodeRef) -> Self { + Self { + marker: PhantomData, + node, + } + } +} + +impl ToNode for VLArrayExpr { + fn node(&self) -> NodeRef { + self.node + } +} + +impl Aggregate for VLArrayExpr { + fn to_nodes(&self, nodes: &mut Vec) { + nodes.push(self.node); + } + fn from_nodes>(iter: &mut I) -> Self { + Self::from_node(iter.next().unwrap()) + } +} + +#[derive(Clone, Copy, Debug)] +pub struct VLArrayVar { + marker: PhantomData, + node: NodeRef, +} + +impl FromNode for VLArrayVar { + fn from_node(node: NodeRef) -> Self { + Self { + marker: PhantomData, + node, + } + } +} + +impl ToNode for VLArrayVar { + fn node(&self) -> NodeRef { + self.node + } +} + +impl Aggregate for VLArrayVar { + fn to_nodes(&self, nodes: &mut Vec) { + nodes.push(self.node); + } + fn from_nodes>(iter: &mut I) -> Self { + Self::from_node(iter.next().unwrap()) + } +} + +impl VLArrayVar { + pub fn read>>(&self, i: I) -> Expr { + let i = i.into(); + if need_runtime_check() { + lc_assert!(i.cmplt(self.len()), "VLArrayVar::read out of bounds"); + } + + Expr::::from_node(__current_scope(|b| { + let gep = b.call(Func::GetElementPtr, &[self.node, i.node()], T::type_()); + b.call(Func::Load, &[gep], T::type_()) + })) + } + pub fn len(&self) -> Expr { + match self.node.type_().as_ref() { + Type::Array(ArrayType { element: _, length }) => (*length as u32).expr(), + _ => unreachable!(), + } + } + pub fn static_len(&self) -> usize { + match self.node.type_().as_ref() { + Type::Array(ArrayType { element: _, length }) => *length, + _ => unreachable!(), + } + } + pub fn write>, V: Into>>(&self, i: I, value: V) { + let i = i.into(); + let value = value.into(); + + if need_runtime_check() { + lc_assert!(i.cmplt(self.len()), "VLArrayVar::read out of bounds"); + } + + __current_scope(|b| { + let gep = b.call(Func::GetElementPtr, &[self.node, i.node()], T::type_()); + b.update(gep, value.node()); + }); + } + pub fn load(&self) -> VLArrayExpr { + VLArrayExpr::from_node(__current_scope(|b| { + b.call(Func::Load, &[self.node], self.node.type_().clone()) + })) + } + pub fn store(&self, value: VLArrayExpr) { + __current_scope(|b| { + b.update(self.node, value.node); + }); + } + pub fn zero(length: usize) -> Self { + FromNode::from_node(__current_scope(|b| { + b.local_zero_init(ir::context::register_type(Type::Array(ArrayType { + element: T::type_(), + length, + }))) + })) + } +} + +impl VLArrayExpr { + pub fn zero(length: usize) -> Self { + let node = __current_scope(|b| { + b.call( + Func::ZeroInitializer, + &[], + ir::context::register_type(Type::Array(ArrayType { + element: T::type_(), + length, + })), + ) + }); + Self::from_node(node) + } + pub fn static_len(&self) -> usize { + match self.node.type_().as_ref() { + Type::Array(ArrayType { element: _, length }) => *length, + _ => unreachable!(), + } + } + pub fn read(&self, i: I) -> Expr { + let i = i.to_u64(); + if need_runtime_check() { + lc_assert!(i.cmplt(self.len())); + } + + Expr::::from_node(__current_scope(|b| { + b.call(Func::ExtractElement, &[self.node, i.node()], T::type_()) + })) + } + pub fn len(&self) -> Expr { + match self.node.type_().as_ref() { + Type::Array(ArrayType { element: _, length }) => (*length as u64).expr(), + _ => unreachable!(), + } + } +} diff --git a/luisa_compute/src/lang/types/core.rs b/luisa_compute/src/lang/types/core.rs new file mode 100644 index 0000000..9898c3e --- /dev/null +++ b/luisa_compute/src/lang/types/core.rs @@ -0,0 +1,163 @@ +use super::*; +use std::ops::Deref; + +// This is a hack in order to get rust-analyzer to display type hints as Expr +// instead of Expr, which is rather redundant and generally clutters things up. +pub(crate) mod prim { + use super::*; + + #[derive(Clone, Copy, Debug)] + pub struct Expr { + pub(crate) node: NodeRef, + pub(crate) _phantom: PhantomData, + } + + #[derive(Clone, Copy, Debug)] + pub struct Var { + pub(crate) node: NodeRef, + pub(crate) _phantom: PhantomData, + } +} + +impl Aggregate for prim::Expr { + fn to_nodes(&self, nodes: &mut Vec) { + nodes.push(self.node); + } + fn from_nodes>(iter: &mut I) -> Self { + Self { + node: iter.next().unwrap(), + _phantom: PhantomData, + } + } +} + +impl Aggregate for prim::Var { + fn to_nodes(&self, nodes: &mut Vec) { + nodes.push(self.node); + } + fn from_nodes>(iter: &mut I) -> Self { + Self { + node: iter.next().unwrap(), + _phantom: PhantomData, + } + } +} + +impl FromNode for prim::Expr { + fn from_node(node: NodeRef) -> Self { + Self { + node, + _phantom: PhantomData, + } + } +} +impl ToNode for prim::Expr { + fn node(&self) -> NodeRef { + self.node + } +} + +impl Deref for prim::Var +where + prim::Var: VarProxy, +{ + type Target = T::Expr; + fn deref(&self) -> &Self::Target { + self._deref() + } +} + +macro_rules! impl_prim { + ($t:ty) => { + impl From<$t> for prim::Expr<$t> { + fn from(v: $t) -> Self { + (v).expr() + } + } + impl From> for prim::Expr<$t> { + fn from(v: Var<$t>) -> Self { + v.load() + } + } + impl FromNode for prim::Var<$t> { + fn from_node(node: NodeRef) -> Self { + Self { + node, + _phantom: PhantomData, + } + } + } + impl ToNode for prim::Var<$t> { + fn node(&self) -> NodeRef { + self.node + } + } + impl ExprProxy for prim::Expr<$t> { + type Value = $t; + } + impl VarProxy for prim::Var<$t> { + type Value = $t; + } + impl Value for $t { + type Expr = prim::Expr<$t>; + type Var = prim::Var<$t>; + fn fields() -> Vec { + vec![] + } + } + impl_callable_param!($t, prim::Expr<$t>, prim::Var<$t>); + }; +} + +impl_prim!(bool); +impl_prim!(u32); +impl_prim!(u64); +impl_prim!(i32); +impl_prim!(i64); +impl_prim!(i16); +impl_prim!(u16); +impl_prim!(f16); +impl_prim!(f32); +impl_prim!(f64); + +pub type Bool = prim::Expr; +pub type F16 = prim::Expr; +pub type F32 = prim::Expr; +pub type F64 = prim::Expr; +pub type I16 = prim::Expr; +pub type I32 = prim::Expr; +pub type I64 = prim::Expr; +pub type U16 = prim::Expr; +pub type U32 = prim::Expr; +pub type U64 = prim::Expr; + +pub type F16Var = prim::Var; +pub type F32Var = prim::Var; +pub type F64Var = prim::Var; +pub type I16Var = prim::Var; +pub type I32Var = prim::Var; +pub type I64Var = prim::Var; +pub type U16Var = prim::Var; +pub type U32Var = prim::Var; +pub type U64Var = prim::Var; + +pub type Half = prim::Expr; +pub type Float = prim::Expr; +pub type Double = prim::Expr; +pub type Int = prim::Expr; +pub type Long = prim::Expr; +pub type Uint = prim::Expr; +pub type Ulong = prim::Expr; +pub type Short = prim::Expr; +pub type Ushort = prim::Expr; + +pub type BoolVar = prim::Var; +pub type HalfVar = prim::Var; +pub type FloatVar = prim::Var; +pub type DoubleVar = prim::Var; +pub type IntVar = prim::Var; +pub type LongVar = prim::Var; +pub type UintVar = prim::Var; +pub type UlongVar = prim::Var; +pub type ShortVar = prim::Var; +pub type UshortVar = prim::Var; diff --git a/luisa_compute/src/lang/types/dynamic.rs b/luisa_compute/src/lang/types/dynamic.rs new file mode 100644 index 0000000..1744779 --- /dev/null +++ b/luisa_compute/src/lang/types/dynamic.rs @@ -0,0 +1,210 @@ +use super::array::{VLArrayExpr, VLArrayVar}; +use super::*; +use ir::ArrayType; +use std::any::Any; +use std::rc::Rc; + +#[derive(Clone, Copy)] +pub struct DynExpr { + node: NodeRef, +} + +impl From for DynExpr { + fn from(value: T) -> Self { + Self { node: value.node() } + } +} + +impl From for DynVar { + fn from(value: T) -> Self { + Self { node: value.node() } + } +} + +impl DynExpr { + pub fn downcast(&self) -> Option> { + if ir::context::is_type_equal(self.node.type_(), &T::type_()) { + Some(Expr::::from_node(self.node)) + } else { + None + } + } + pub fn get(&self) -> Expr { + self.downcast::().unwrap_or_else(|| { + panic!( + "DynExpr::get: type mismatch: expected {}, got {}", + std::any::type_name::(), + self.node.type_().to_string() + ) + }) + } + pub fn downcast_array(&self, len: usize) -> Option> { + let array_type = ir::context::register_type(Type::Array(ArrayType { + element: T::type_(), + length: len, + })); + if ir::context::is_type_equal(self.node.type_(), &array_type) { + Some(VLArrayExpr::::from_node(self.node)) + } else { + None + } + } + pub fn get_array(&self, len: usize) -> VLArrayExpr { + let array_type = ir::context::register_type(Type::Array(ArrayType { + element: T::type_(), + length: len, + })); + self.downcast_array::(len).unwrap_or_else(|| { + panic!( + "DynExpr::get: type mismatch: expected {}, got {}", + array_type, + self.node.type_().to_string() + ) + }) + } + pub fn new(expr: E) -> Self { + Self { node: expr.node() } + } +} + +impl CallableParameter for DynExpr { + fn def_param(arg: Option>, builder: &mut KernelBuilder) -> Self { + let arg = arg.unwrap_or_else(|| panic!("DynExpr should be used in DynCallable only!")); + let arg = arg.downcast_ref::().unwrap(); + let node = builder.arg(arg.node.type_().clone(), true); + Self { node } + } + fn encode(&self, encoder: &mut CallableArgEncoder) { + encoder.args.push(self.node) + } +} + +impl Aggregate for DynExpr { + fn to_nodes(&self, nodes: &mut Vec) { + nodes.push(self.node) + } + fn from_nodes>(iter: &mut I) -> Self { + Self { + node: iter.next().unwrap(), + } + } +} + +impl FromNode for DynExpr { + fn from_node(node: NodeRef) -> Self { + Self { node } + } +} + +impl ToNode for DynExpr { + fn node(&self) -> NodeRef { + self.node + } +} + +unsafe impl CallableRet for DynExpr { + fn _return(&self) -> CArc { + __current_scope(|b| { + b.return_(self.node); + }); + self.node.type_().clone() + } + fn _from_return(node: NodeRef) -> Self { + Self::from_node(node) + } +} + +impl Aggregate for DynVar { + fn to_nodes(&self, nodes: &mut Vec) { + nodes.push(self.node) + } + fn from_nodes>(iter: &mut I) -> Self { + Self { + node: iter.next().unwrap(), + } + } +} + +impl FromNode for DynVar { + fn from_node(node: NodeRef) -> Self { + Self { node } + } +} + +impl ToNode for DynVar { + fn node(&self) -> NodeRef { + self.node + } +} + +#[derive(Clone, Copy)] +pub struct DynVar { + node: NodeRef, +} + +impl CallableParameter for DynVar { + fn def_param(arg: Option>, builder: &mut KernelBuilder) -> Self { + let arg = arg.unwrap_or_else(|| panic!("DynVar should be used in DynCallable only!")); + let arg = arg.downcast_ref::().unwrap(); + let node = builder.arg(arg.node.type_().clone(), false); + Self { node } + } + fn encode(&self, encoder: &mut CallableArgEncoder) { + encoder.args.push(self.node) + } +} + +impl DynVar { + pub fn downcast(&self) -> Option> { + if ir::context::is_type_equal(self.node.type_(), &T::type_()) { + Some(Var::::from_node(self.node)) + } else { + None + } + } + pub fn get(&self) -> Var { + self.downcast::().unwrap_or_else(|| { + panic!( + "DynVar::get: type mismatch: expected {}, got {}", + std::any::type_name::(), + self.node.type_().to_string() + ) + }) + } + pub fn downcast_array(&self, len: usize) -> Option> { + let array_type = ir::context::register_type(Type::Array(ArrayType { + element: T::type_(), + length: len, + })); + if ir::context::is_type_equal(self.node.type_(), &array_type) { + Some(VLArrayVar::::from_node(self.node)) + } else { + None + } + } + pub fn get_array(&self, len: usize) -> VLArrayVar { + let array_type = ir::context::register_type(Type::Array(ArrayType { + element: T::type_(), + length: len, + })); + self.downcast_array::(len).unwrap_or_else(|| { + panic!( + "DynExpr::get: type mismatch: expected {}, got {}", + array_type, + self.node.type_().to_string() + ) + }) + } + pub fn load(&self) -> DynExpr { + DynExpr { + node: __current_scope(|b| b.call(Func::Load, &[self.node], self.node.type_().clone())), + } + } + pub fn store(&self, value: &DynExpr) { + __current_scope(|b| b.update(self.node, value.node)); + } + pub fn zero() -> Self { + let v = local_zeroed::(); + Self { node: v.node() } + } +} diff --git a/luisa_compute/src/lang/types/shared.rs b/luisa_compute/src/lang/types/shared.rs new file mode 100644 index 0000000..a374d69 --- /dev/null +++ b/luisa_compute/src/lang/types/shared.rs @@ -0,0 +1,68 @@ +use super::array::VLArrayExpr; +use super::*; +use crate::lang::index::IntoIndex; +use ir::ArrayType; + +pub struct Shared { + marker: PhantomData, + node: NodeRef, +} +impl Shared { + pub fn new(length: usize) -> Self { + Self { + marker: PhantomData, + node: __current_scope(|b| { + let shared = new_node( + b.pools(), + Node::new( + CArc::new(Instruction::Shared), + ir::context::register_type(Type::Array(ArrayType { + element: T::type_(), + length, + })), + ), + ); + RECORDER.with(|r| { + let mut r = r.borrow_mut(); + r.shared.push(shared); + }); + shared + }), + } + } + pub fn len(&self) -> Expr { + match self.node.type_().as_ref() { + Type::Array(ArrayType { element: _, length }) => (*length as u64).expr(), + _ => unreachable!(), + } + } + pub fn static_len(&self) -> usize { + match self.node.type_().as_ref() { + Type::Array(ArrayType { element: _, length }) => *length, + _ => unreachable!(), + } + } + pub fn write>>(&self, i: I, value: V) { + let i = i.to_u64(); + let value = value.into(); + + if need_runtime_check() { + lc_assert!(i.cmplt(self.len()), "VLArrayVar::read out of bounds"); + } + + __current_scope(|b| { + let gep = b.call(Func::GetElementPtr, &[self.node, i.node()], T::type_()); + b.update(gep, value.node()); + }); + } + pub fn load(&self) -> VLArrayExpr { + VLArrayExpr::from_node(__current_scope(|b| { + b.call(Func::Load, &[self.node], self.node.type_().clone()) + })) + } + pub fn store(&self, value: VLArrayExpr) { + __current_scope(|b| { + b.update(self.node, value.node); + }); + } +} diff --git a/luisa_compute/src/lang/math.rs b/luisa_compute/src/lang/types/vector.rs similarity index 86% rename from luisa_compute/src/lang/math.rs rename to luisa_compute/src/lang/types/vector.rs index 6c4e0f4..a1505ea 100644 --- a/luisa_compute/src/lang/math.rs +++ b/luisa_compute/src/lang/types/vector.rs @@ -1,12 +1,6 @@ -pub use super::swizzle::*; -use super::{Aggregate, ExprProxy, Value, VarProxy, __extract, traits::*, Float}; -use crate::*; -use half::f16; -use luisa_compute_ir::{ - context::register_type, - ir::{Func, MatrixType, NodeRef, Primitive, Type, VectorElementType, VectorType}, - TypeOf, -}; +use super::core::*; +use super::*; +use ir::{MatrixType, Primitive, VectorElementType, VectorType}; use serde::{Deserialize, Serialize}; use std::ops::Mul; @@ -44,7 +38,7 @@ macro_rules! def_vec { macro_rules! def_packed_vec { ($name:ident, $vec_type:ident, $glam_type:ident, $scalar:ty, $($comp:ident), *) => { #[repr(C)] - #[derive(Copy, Clone, Debug, Default, __Value,PartialEq, Serialize, Deserialize)] + #[derive(Copy, Clone, Debug, Default, Value, PartialEq, Serialize, Deserialize)] pub struct $name { $(pub $comp: $scalar), * } @@ -87,7 +81,7 @@ macro_rules! def_packed_vec { macro_rules! def_packed_vec_no_glam { ($name:ident, $vec_type:ident, $scalar:ty, $($comp:ident), *) => { #[repr(C)] - #[derive(Copy, Clone, Debug, Default, __Value)] + #[derive(Copy, Clone, Debug, Default, Value)] pub struct $name { $(pub $comp: $scalar), * } @@ -363,11 +357,11 @@ macro_rules! impl_proxy_fields { ($vec:ident, $proxy:ident, $scalar:ty, x) => { impl $proxy { #[inline] - pub fn x(&self) -> Expr<$scalar> { + pub fn x(&self) -> prim::Expr<$scalar> { FromNode::from_node(__extract::<$scalar>(self.node, 0)) } #[inline] - pub fn set_x(&self, value: Expr<$scalar>) -> Self { + pub fn set_x(&self, value: prim::Expr<$scalar>) -> Self { Self::from_node(__insert::<$vec>(self.node, 0, ToNode::node(&value))) } } @@ -375,11 +369,11 @@ macro_rules! impl_proxy_fields { ($vec:ident,$proxy:ident, $scalar:ty, y) => { impl $proxy { #[inline] - pub fn y(&self) -> Expr<$scalar> { + pub fn y(&self) -> prim::Expr<$scalar> { FromNode::from_node(__extract::<$scalar>(self.node, 1)) } #[inline] - pub fn set_y(&self, value: Expr<$scalar>) -> Self { + pub fn set_y(&self, value: prim::Expr<$scalar>) -> Self { Self::from_node(__insert::<$vec>(self.node, 1, ToNode::node(&value))) } } @@ -387,11 +381,11 @@ macro_rules! impl_proxy_fields { ($vec:ident,$proxy:ident, $scalar:ty, z) => { impl $proxy { #[inline] - pub fn z(&self) -> Expr<$scalar> { + pub fn z(&self) -> prim::Expr<$scalar> { FromNode::from_node(__extract::<$scalar>(self.node, 2)) } #[inline] - pub fn set_z(&self, value: Expr<$scalar>) -> Self { + pub fn set_z(&self, value: prim::Expr<$scalar>) -> Self { Self::from_node(__insert::<$vec>(self.node, 2, ToNode::node(&value))) } } @@ -399,11 +393,11 @@ macro_rules! impl_proxy_fields { ($vec:ident,$proxy:ident, $scalar:ty, w) => { impl $proxy { #[inline] - pub fn w(&self) -> Expr<$scalar> { + pub fn w(&self) -> prim::Expr<$scalar> { FromNode::from_node(__extract::<$scalar>(self.node, 3)) } #[inline] - pub fn set_w(&self, value: Expr<$scalar>) -> Self { + pub fn set_w(&self, value: prim::Expr<$scalar>) -> Self { Self::from_node(__insert::<$vec>(self.node, 3, ToNode::node(&value))) } } @@ -481,7 +475,7 @@ macro_rules! impl_vec_proxy { } impl VectorVarTrait for $expr_proxy { } impl ScalarOrVector for $expr_proxy { - type Element = Expr<$scalar>; + type Element = prim::Expr<$scalar>; type ElementHost = $scalar; } impl BuiltinVarTrait for $expr_proxy { } @@ -537,15 +531,21 @@ macro_rules! impl_vec_proxy { $(impl_var_proxy_fields!($var_proxy, $scalar, $comp);)* impl $expr_proxy { #[inline] - pub fn new($($comp: Expr<$scalar>), *) -> Self { + pub fn new($($comp: prim::Expr<$scalar>), *) -> Self { Self { node: __compose::<$vec>(&[$(ToNode::node(&$comp)), *]), } } - pub fn at(&self, index: usize) -> Expr<$scalar> { + pub fn at(&self, index: usize) -> prim::Expr<$scalar> { FromNode::from_node(__extract::<$scalar>(self.node, index)) } } + impl $vec { + #[inline] + pub fn expr($($comp: impl Into>), *) -> $expr_proxy { + $expr_proxy::new($($comp.into()), *) + } + } }; } @@ -646,6 +646,12 @@ macro_rules! impl_mat_proxy { Expr::<$vec>::from_node(__extract::<$vec>(self.node, index)) } } + impl $mat { + #[inline] + pub fn expr($($comp: impl Into>), *) -> $expr_proxy { + $expr_proxy::new($($comp.into()), *) + } + } }; } @@ -663,7 +669,18 @@ impl_vec_proxy!(Float4, Float4Expr, Float4Var, f32, Float32, 4, x, y, z, w); impl_vec_proxy!(Double2, Double2Expr, Double2Var, f64, Float64, 2, x, y); impl_vec_proxy!(Double3, Double3Expr, Double3Var, f64, Float64, 3, x, y, z); -impl_vec_proxy!(Double4, Double4Expr, Double4Var, f64, Float64, 4, x, y, z, w); +impl_vec_proxy!( + Double4, + Double4Expr, + Double4Var, + f64, + Float64, + 4, + x, + y, + z, + w +); impl_vec_proxy!(Ushort2, Ushort2Expr, Ushort2Var, u16, Uint16, 2, x, y); impl_vec_proxy!(Ushort3, Ushort3Expr, Ushort3Var, u16, Uint16, 3, x, y, z); @@ -784,16 +801,16 @@ macro_rules! impl_binop { })) } } - impl std::ops::$tr> for $proxy { + impl std::ops::$tr> for $proxy { type Output = $proxy; - fn $m(self, rhs: PrimExpr<$scalar>) -> Self::Output { + fn $m(self, rhs: prim::Expr<$scalar>) -> Self::Output { let rhs = Self::splat(rhs); <$proxy>::from_node(__current_scope(|s| { s.call(Func::$tr, &[self.node, rhs.node], <$t as TypeOf>::type_()) })) } } - impl std::ops::$tr<$proxy> for PrimExpr<$scalar> { + impl std::ops::$tr<$proxy> for prim::Expr<$scalar> { type Output = $proxy; fn $m(self, rhs: $proxy) -> Self::Output { let lhs = <$proxy>::splat(self); @@ -838,16 +855,16 @@ macro_rules! impl_binop_for_mat { })) } } - impl std::ops::$tr> for $proxy { + impl std::ops::$tr> for $proxy { type Output = $proxy; - fn $m(self, rhs: PrimExpr<$scalar>) -> Self::Output { + fn $m(self, rhs: prim::Expr<$scalar>) -> Self::Output { let rhs = Self::fill(rhs); <$proxy>::from_node(__current_scope(|s| { s.call(Func::$tr, &[self.node, rhs.node], <$t as TypeOf>::type_()) })) } } - impl std::ops::$tr<$proxy> for PrimExpr<$scalar> { + impl std::ops::$tr<$proxy> for prim::Expr<$scalar> { type Output = $proxy; fn $m(self, rhs: $proxy) -> Self::Output { let lhs = <$proxy>::fill(self); @@ -929,9 +946,9 @@ macro_rules! impl_arith_binop_for_mat { })) } } - impl std::ops::Mul> for $proxy { + impl std::ops::Mul> for $proxy { type Output = $proxy; - fn mul(self, rhs: PrimExpr<$scalar>) -> Self::Output { + fn mul(self, rhs: prim::Expr<$scalar>) -> Self::Output { let rhs = Self::fill(rhs); <$proxy>::from_node(__current_scope(|s| { s.call( @@ -942,7 +959,7 @@ macro_rules! impl_arith_binop_for_mat { })) } } - impl std::ops::Mul<$proxy> for PrimExpr<$scalar> { + impl std::ops::Mul<$proxy> for prim::Expr<$scalar> { type Output = $proxy; fn mul(self, rhs: $proxy) -> Self::Output { let lhs = <$proxy>::fill(self); @@ -965,15 +982,15 @@ macro_rules! impl_arith_binop_for_mat { impl std::ops::Rem<$scalar> for $proxy { type Output = $proxy; fn rem(self, rhs: $scalar) -> Self::Output { - let rhs: PrimExpr<$scalar> = rhs.into(); + let rhs: prim::Expr<$scalar> = rhs.into(); <$proxy>::from_node(__current_scope(|s| { s.call(Func::Rem, &[self.node, rhs.node], <$t as TypeOf>::type_()) })) } } - impl std::ops::Rem> for $proxy { + impl std::ops::Rem> for $proxy { type Output = $proxy; - fn rem(self, rhs: PrimExpr<$scalar>) -> Self::Output { + fn rem(self, rhs: prim::Expr<$scalar>) -> Self::Output { <$proxy>::from_node(__current_scope(|s| { s.call(Func::Rem, &[self.node, rhs.node], <$t as TypeOf>::type_()) })) @@ -989,15 +1006,15 @@ macro_rules! impl_arith_binop_for_mat { impl std::ops::Div<$scalar> for $proxy { type Output = $proxy; fn div(self, rhs: $scalar) -> Self::Output { - let rhs: PrimExpr<$scalar> = rhs.into(); + let rhs: prim::Expr<$scalar> = rhs.into(); <$proxy>::from_node(__current_scope(|s| { s.call(Func::Div, &[self.node, rhs.node], <$t as TypeOf>::type_()) })) } } - impl std::ops::Div> for $proxy { + impl std::ops::Div> for $proxy { type Output = $proxy; - fn div(self, rhs: PrimExpr<$scalar>) -> Self::Output { + fn div(self, rhs: prim::Expr<$scalar>) -> Self::Output { <$proxy>::from_node(__current_scope(|s| { s.call(Func::Div, &[self.node, rhs.node], <$t as TypeOf>::type_()) })) @@ -1081,7 +1098,7 @@ macro_rules! impl_bool_binop { bitxor_assign ); impl $proxy { - pub fn splat>>(value: V) -> Self { + pub fn splat>>(value: V) -> Self { let value = value.into(); <$proxy>::from_node(__current_scope(|s| { s.call(Func::Vec, &[value.node], <$t as TypeOf>::type_()) @@ -1093,12 +1110,12 @@ macro_rules! impl_bool_binop { pub fn one() -> Self { Self::splat(true) } - pub fn all(&self) -> Expr { + pub fn all(&self) -> prim::Expr { Expr::::from_node(__current_scope(|s| { s.call(Func::All, &[self.node], ::type_()) })) } - pub fn any(&self) -> Expr { + pub fn any(&self) -> prim::Expr { Expr::::from_node(__current_scope(|s| { s.call(Func::Any, &[self.node], ::type_()) })) @@ -1116,31 +1133,31 @@ macro_rules! impl_reduce { ($t:ty, $scalar:ty, $proxy:ty) => { impl $proxy { #[inline] - pub fn reduce_sum(&self) -> Expr<$scalar> { + pub fn reduce_sum(&self) -> prim::Expr<$scalar> { FromNode::from_node(__current_scope(|s| { s.call(Func::ReduceSum, &[self.node], <$scalar as TypeOf>::type_()) })) } #[inline] - pub fn reduce_prod(&self) -> Expr<$scalar> { + pub fn reduce_prod(&self) -> prim::Expr<$scalar> { FromNode::from_node(__current_scope(|s| { s.call(Func::ReduceProd, &[self.node], <$scalar as TypeOf>::type_()) })) } #[inline] - pub fn reduce_min(&self) -> Expr<$scalar> { + pub fn reduce_min(&self) -> prim::Expr<$scalar> { FromNode::from_node(__current_scope(|s| { s.call(Func::ReduceMin, &[self.node], <$scalar as TypeOf>::type_()) })) } #[inline] - pub fn reduce_max(&self) -> Expr<$scalar> { + pub fn reduce_max(&self) -> prim::Expr<$scalar> { FromNode::from_node(__current_scope(|s| { s.call(Func::ReduceMax, &[self.node], <$scalar as TypeOf>::type_()) })) } #[inline] - pub fn dot(&self, rhs: $proxy) -> Expr<$scalar> { + pub fn dot(&self, rhs: $proxy) -> prim::Expr<$scalar> { FromNode::from_node(__current_scope(|s| { s.call( Func::Dot, @@ -1155,7 +1172,7 @@ macro_rules! impl_reduce { macro_rules! impl_common_op { ($t:ty, $scalar:ty, $proxy:ty) => { impl $proxy { - pub fn splat>>(value: V) -> Self { + pub fn splat>>(value: V) -> Self { let value = value.into(); <$proxy>::from_node(__current_scope(|s| { s.call(Func::Vec, &[value.node], <$t as TypeOf>::type_()) @@ -1174,7 +1191,7 @@ macro_rules! impl_vec_op { ($t:ty, $scalar:ty, $proxy:ty, $mat:ty) => { impl $proxy { #[inline] - pub fn length(&self) -> Expr<$scalar> { + pub fn length(&self) -> prim::Expr<$scalar> { FromNode::from_node(__current_scope(|s| { s.call(Func::Length, &[self.node], <$scalar as TypeOf>::type_()) })) @@ -1186,7 +1203,7 @@ macro_rules! impl_vec_op { })) } #[inline] - pub fn length_squared(&self) -> Expr<$scalar> { + pub fn length_squared(&self) -> prim::Expr<$scalar> { FromNode::from_node(__current_scope(|s| { s.call( Func::LengthSquared, @@ -1196,11 +1213,11 @@ macro_rules! impl_vec_op { })) } #[inline] - pub fn distance(&self, rhs: $proxy) -> Expr<$scalar> { + pub fn distance(&self, rhs: $proxy) -> prim::Expr<$scalar> { (*self - rhs).length() } #[inline] - pub fn distance_squared(&self, rhs: $proxy) -> Expr<$scalar> { + pub fn distance_squared(&self, rhs: $proxy) -> prim::Expr<$scalar> { (*self - rhs).length_squared() } #[inline] @@ -1250,7 +1267,7 @@ macro_rules! impl_arith_binop_f16 { macro_rules! impl_common_op_f16 { ($t:ty, $scalar:ty, $proxy:ty) => { impl $proxy { - pub fn splat>>(value: V) -> Self { + pub fn splat>>(value: V) -> Self { let value = value.into(); <$proxy>::from_node(__current_scope(|s| { s.call(Func::Vec, &[value.node], <$t as TypeOf>::type_()) @@ -1512,7 +1529,7 @@ macro_rules! impl_var_trait2 { impl VarCmpEq for $t {} impl From<$v> for $t { fn from(v: $v) -> Self { - Self::new(const_(v.x), const_(v.y)) + Self::new((v.x).expr(), (v.y).expr()) } } }; @@ -1537,7 +1554,7 @@ macro_rules! impl_var_trait3 { impl VarCmpEq for $t {} impl From<$v> for $t { fn from(v: $v) -> Self { - Self::new(const_(v.x), const_(v.y), const_(v.z)) + Self::new(v.x.expr(), v.y.expr(), v.z.expr()) } } }; @@ -1562,7 +1579,7 @@ macro_rules! impl_var_trait4 { impl VarCmpEq for $t {} impl From<$v> for $t { fn from(v: $v) -> Self { - Self::new(const_(v.x), const_(v.y), const_(v.z), const_(v.w)) + Self::new(v.x.expr(), v.y.expr(), v.z.expr(), v.w.expr()) } } }; @@ -1667,11 +1684,11 @@ impl Mul for Mat2Expr { } } impl Mat2Expr { - pub fn fill(e: impl Into> + Copy) -> Self { - Self::new(make_float2(e, e), make_float2(e, e)) + pub fn fill(e: impl Into> + Copy) -> Self { + Self::new(Float2::expr(e, e), Float2::expr(e, e)) } pub fn eye(e: Expr) -> Self { - Self::new(make_float2(e.x(), 0.0), make_float2(0.0, e.y())) + Self::new(Float2::expr(e.x(), 0.0), Float2::expr(0.0, e.y())) } pub fn inverse(&self) -> Self { Mat2Expr::from_node(__current_scope(|s| { @@ -1683,7 +1700,7 @@ impl Mat2Expr { s.call(Func::Transpose, &[self.node], ::type_()) })) } - pub fn determinant(&self) -> Float { + pub fn determinant(&self) -> prim::Expr { FromNode::from_node(__current_scope(|s| { s.call(Func::Determinant, &[self.node], ::type_()) })) @@ -1704,18 +1721,18 @@ impl Mul for Mat3Expr { } } impl Mat3Expr { - pub fn fill(e: impl Into> + Copy) -> Self { + pub fn fill(e: impl Into> + Copy) -> Self { Self::new( - make_float3(e, e, e), - make_float3(e, e, e), - make_float3(e, e, e), + Float3::expr(e, e, e), + Float3::expr(e, e, e), + Float3::expr(e, e, e), ) } pub fn eye(e: Expr) -> Self { Self::new( - make_float3(e.x(), 0.0, 0.0), - make_float3(0.0, e.y(), 0.0), - make_float3(0.0, 0.0, e.z()), + Float3::expr(e.x(), 0.0, 0.0), + Float3::expr(0.0, e.y(), 0.0), + Float3::expr(0.0, 0.0, e.z()), ) } pub fn inverse(&self) -> Self { @@ -1728,7 +1745,7 @@ impl Mat3Expr { s.call(Func::Transpose, &[self.node], ::type_()) })) } - pub fn determinant(&self) -> Float { + pub fn determinant(&self) -> prim::Expr { FromNode::from_node(__current_scope(|s| { s.call(Func::Determinant, &[self.node], ::type_()) })) @@ -1749,20 +1766,20 @@ impl Mul for Mat4Expr { } impl_arith_binop_for_mat!(Mat3, f32, Mat3Expr); impl Mat4Expr { - pub fn fill(e: impl Into> + Copy) -> Self { + pub fn fill(e: impl Into> + Copy) -> Self { Self::new( - make_float4(e, e, e, e), - make_float4(e, e, e, e), - make_float4(e, e, e, e), - make_float4(e, e, e, e), + Float4::expr(e, e, e, e), + Float4::expr(e, e, e, e), + Float4::expr(e, e, e, e), + Float4::expr(e, e, e, e), ) } pub fn eye(e: Expr) -> Self { Self::new( - make_float4(e.x(), 0.0, 0.0, 0.0), - make_float4(0.0, e.y(), 0.0, 0.0), - make_float4(0.0, 0.0, e.z(), 0.0), - make_float4(0.0, 0.0, 0.0, e.w()), + Float4::expr(e.x(), 0.0, 0.0, 0.0), + Float4::expr(0.0, e.y(), 0.0, 0.0), + Float4::expr(0.0, 0.0, e.z(), 0.0), + Float4::expr(0.0, 0.0, 0.0, e.w()), ) } pub fn inverse(&self) -> Self { @@ -1775,7 +1792,7 @@ impl Mat4Expr { s.call(Func::Transpose, &[self.node], ::type_()) })) } - pub fn determinant(&self) -> Float { + pub fn determinant(&self) -> prim::Expr { FromNode::from_node(__current_scope(|s| { s.call(Func::Determinant, &[self.node], ::type_()) })) @@ -1783,199 +1800,11 @@ impl Mat4Expr { } impl_arith_binop_for_mat!(Mat4, f32, Mat4Expr); -#[inline] -pub fn make_half2>, Y: Into>>(x: X, y: Y) -> Expr { - Expr::::new(x.into(), y.into()) -} -#[inline] -pub fn make_half3>, Y: Into>, Z: Into>>( - x: X, - y: Y, - z: Z, -) -> Expr { - Expr::::new(x.into(), y.into(), z.into()) -} -#[inline] -pub fn make_half4< - X: Into>, - Y: Into>, - Z: Into>, - W: Into>, ->( - x: X, - y: Y, - z: Z, - w: W, -) -> Expr { - Expr::::new(x.into(), y.into(), z.into(), w.into()) -} - -#[inline] -pub fn make_float2>, Y: Into>>(x: X, y: Y) -> Expr { - Expr::::new(x.into(), y.into()) -} -#[inline] -pub fn make_float3>, Y: Into>, Z: Into>>( - x: X, - y: Y, - z: Z, -) -> Expr { - Expr::::new(x.into(), y.into(), z.into()) -} -#[inline] -pub fn make_float4< - X: Into>, - Y: Into>, - Z: Into>, - W: Into>, ->( - x: X, - y: Y, - z: Z, - w: W, -) -> Expr { - Expr::::new(x.into(), y.into(), z.into(), w.into()) -} -#[inline] -pub fn make_float2x2>, Y: Into>>(x: X, y: Y) -> Expr { - Expr::::new(x.into(), y.into()) -} -#[inline] -pub fn make_float3x3>, Y: Into>, Z: Into>>( - x: X, - y: Y, - z: Z, -) -> Expr { - Expr::::new(x.into(), y.into(), z.into()) -} -#[inline] -pub fn make_float4x4< - X: Into>, - Y: Into>, - Z: Into>, - W: Into>, ->( - x: X, - y: Y, - z: Z, - w: W, -) -> Expr { - Expr::::new(x.into(), y.into(), z.into(), w.into()) -} - -#[inline] -pub fn make_int2>, Y: Into>>(x: X, y: Y) -> Expr { - Expr::::new(x.into(), y.into()) -} -#[inline] -pub fn make_int3>, Y: Into>, Z: Into>>( - x: X, - y: Y, - z: Z, -) -> Expr { - Expr::::new(x.into(), y.into(), z.into()) -} -#[inline] -pub fn make_int4< - X: Into>, - Y: Into>, - Z: Into>, - W: Into>, ->( - x: X, - y: Y, - z: Z, - w: W, -) -> Expr { - Expr::::new(x.into(), y.into(), z.into(), w.into()) -} -#[inline] -pub fn make_uint2>, Y: Into>>(x: X, y: Y) -> Expr { - Expr::::new(x.into(), y.into()) -} -#[inline] -pub fn make_uint3>, Y: Into>, Z: Into>>( - x: X, - y: Y, - z: Z, -) -> Expr { - Expr::::new(x.into(), y.into(), z.into()) -} -#[inline] -pub fn make_uint4< - X: Into>, - Y: Into>, - Z: Into>, - W: Into>, ->( - x: X, - y: Y, - z: Z, - w: W, -) -> Expr { - Expr::::new(x.into(), y.into(), z.into(), w.into()) -} - -#[inline] -pub fn make_short2>, Y: Into>>(x: X, y: Y) -> Expr { - Expr::::new(x.into(), y.into()) -} -#[inline] -pub fn make_short3>, Y: Into>, Z: Into>>( - x: X, - y: Y, - z: Z, -) -> Expr { - Expr::::new(x.into(), y.into(), z.into()) -} -#[inline] -pub fn make_short4< - X: Into>, - Y: Into>, - Z: Into>, - W: Into>, ->( - x: X, - y: Y, - z: Z, - w: W, -) -> Expr { - Expr::::new(x.into(), y.into(), z.into(), w.into()) -} - -#[inline] -pub fn make_ushort2>, Y: Into>>(x: X, y: Y) -> Expr { - Expr::::new(x.into(), y.into()) -} -#[inline] -pub fn make_ushort3>, Y: Into>, Z: Into>>( - x: X, - y: Y, - z: Z, -) -> Expr { - Expr::::new(x.into(), y.into(), z.into()) -} -#[inline] -pub fn make_ushort4< - X: Into>, - Y: Into>, - Z: Into>, - W: Into>, ->( - x: X, - y: Y, - z: Z, - w: W, -) -> Expr { - Expr::::new(x.into(), y.into(), z.into(), w.into()) -} - #[cfg(test)] mod test { #[test] fn test_size() { - use crate::prelude::*; - use crate::*; + use crate::internal_prelude::*; macro_rules! assert_size { ($ty:ty) => { {assert_eq!(std::mem::size_of::<$ty>(), <$ty as TypeOf>::type_().size());} diff --git a/luisa_compute/src/lib.rs b/luisa_compute/src/lib.rs index f55518b..cb9f853 100644 --- a/luisa_compute/src/lib.rs +++ b/luisa_compute/src/lib.rs @@ -1,67 +1,81 @@ #![allow(unused_unsafe)] +extern crate self as luisa_compute; + +use std::any::Any; use std::backtrace::Backtrace; use std::path::Path; -use std::{any::Any, sync::Arc}; +use std::sync::Arc; pub mod lang; +pub mod printer; pub mod resource; pub mod rtx; pub mod runtime; -pub use half::f16; -use luisa_compute_api_types as api; -pub use luisa_compute_backend as backend; - pub mod prelude { - pub use crate::lang::poly::PolymorphicImpl; - pub use crate::lang::traits::VarTrait; - pub use crate::lang::traits::{CommonVarOp, FloatVarTrait, IntVarTrait, VarCmp, VarCmpEq}; - pub use crate::lang::{ - Aggregate, ExprProxy, FromNode, IndexRead, IndexWrite, KernelBuildFn, KernelParameter, - KernelSignature, Mask, Value, VarProxy, + pub use half::f16; + + pub use crate::lang::control_flow::{ + break_, continue_, for_range, return_, return_v, select, switch, }; - pub use crate::lang::{ - __compose, __cpu_dbg, __current_scope, __env_need_backtrace, __extract, __insert, - __module_pools, __new_user_node, __pop_scope, + pub use crate::lang::functions::{block_size, dispatch_id, dispatch_size, set_block_size}; + pub use crate::lang::index::{IndexRead, IndexWrite}; + pub use crate::lang::ops::*; + pub use crate::lang::swizzle::*; + pub use crate::lang::types::vector::{ + Bool2, Bool3, Bool4, Byte2, Byte3, Byte4, Double2, Double3, Double4, Float2, Float3, + Float4, Half2, Half3, Half4, Int2, Int3, Int4, Long2, Long3, Long4, Mat2, Mat3, Mat4, + PackedBool2, PackedBool3, PackedBool4, PackedFloat2, PackedFloat3, PackedFloat4, + PackedInt2, PackedInt3, PackedInt4, PackedLong2, PackedLong3, PackedLong4, PackedShort2, + PackedShort3, PackedShort4, PackedUint2, PackedUint3, PackedUint4, PackedUlong2, + PackedUlong3, PackedUlong4, PackedUshort2, PackedUshort3, PackedUshort4, Short2, Short3, + Short4, Ubyte2, Ubyte3, Ubyte4, Uint2, Uint3, Uint4, Ulong2, Ulong3, Ulong4, Ushort2, + Ushort3, Ushort4, }; - pub use crate::resource::{IoTexel, StorageTexel}; - pub use crate::runtime::KernelArg; - pub use luisa_compute_ir::TypeOf; + pub use crate::lang::types::{Expr, ExprProxy, Value, Var, VarProxy}; + pub use crate::lang::Aggregate; + pub use crate::resource::{IoTexel, StorageTexel, *}; + pub use crate::runtime::{ + create_static_callable, Command, Device, KernelBuildOptions, Scope, Stream, + }; + pub use crate::{cpu_dbg, if_, lc_assert, lc_unreachable, loop_, struct_, while_, Context}; + + pub use luisa_compute_derive::*; + pub use luisa_compute_track::track; } -pub use api::{ - AccelBuildModificationFlags, AccelBuildRequest, AccelOption, AccelUsageHint, MeshType, - PixelFormat, PixelStorage, -}; -pub use glam; -pub use lang::math; -pub use lang::math::*; -pub use lang::poly; -pub use lang::poly::*; -pub use lang::traits::*; -pub use lang::*; -pub use log; -pub use luisa_compute_derive as derive; -pub use luisa_compute_derive::*; -pub use luisa_compute_ir::ir::UserNodeData; -pub use resource::*; -pub use runtime::*; - -pub mod macros { - pub use crate::{ - cpu_dbg, if_, impl_new_poly_array, impl_polymorphic, lc_assert, lc_dbg, lc_unreachable, - loop_, struct_, var, while_, +mod internal_prelude { + pub(crate) use crate::lang::debug::{CpuFn, __env_need_backtrace, is_cpu_backend}; + pub(crate) use crate::lang::ir::ffi::*; + pub(crate) use crate::lang::ir::{ + new_node, register_type, BasicBlock, Const, Func, Instruction, IrBuilder, Node, + PhiIncoming, Pooled, Type, TypeOf, INVALID_REF, }; - pub use crate::{lc_debug, lc_error, lc_info, lc_log, lc_warn}; + pub(crate) use crate::lang::types::vector::*; + pub(crate) use crate::lang::{ + ir, Recorder, __compose, __extract, __insert, __module_pools, need_runtime_check, FromNode, + NodeRef, ToNode, __current_scope, __pop_scope, RECORDER, + }; + pub(crate) use crate::prelude::*; + pub(crate) use crate::runtime::{ + CallableArgEncoder, CallableParameter, CallableRet, KernelBuilder, + }; + pub(crate) use crate::{get_backtrace, impl_callable_param, ResourceTracker}; + pub(crate) use luisa_compute_backend::Backend; + pub(crate) use std::marker::PhantomData; } +pub use luisa_compute_derive::*; + +use luisa_compute_api_types as api; +pub use {luisa_compute_backend as backend, luisa_compute_sys as sys}; + use lazy_static::lazy_static; use luisa_compute_backend::Backend; -pub use luisa_compute_sys as sys; use parking_lot::lock_api::RawMutex as RawMutexTrait; use parking_lot::{Mutex, RawMutex}; -pub use runtime::{Device, Scope, Stream}; +use runtime::{Device, DeviceHandle, StreamHandle}; use std::collections::HashMap; use std::sync::Weak; diff --git a/luisa_compute/src/lang/printer.rs b/luisa_compute/src/printer.rs similarity index 93% rename from luisa_compute/src/lang/printer.rs rename to luisa_compute/src/printer.rs index b4100aa..61cdc76 100644 --- a/luisa_compute/src/lang/printer.rs +++ b/luisa_compute/src/printer.rs @@ -1,8 +1,16 @@ -use crate::prelude::*; -use crate::*; use parking_lot::RwLock; use std::fmt::Debug; use std::sync::atomic::AtomicBool; +use std::sync::Arc; + +#[doc(hidden)] +pub use log as _log; + +use crate::internal_prelude::*; + +use crate::lang::{pack_to, packed_size}; + +pub use crate::{lc_debug, lc_error, lc_info, lc_warn}; pub type LogFn = Box; struct PrinterItem { @@ -64,7 +72,7 @@ macro_rules! lc_log { // )* // $printer._log($level, printer_args, log_fn); // } - luisa_compute::derive::_log!( + $crate::_log!( $printer, $level, $fmt, @@ -113,7 +121,7 @@ impl Printer { }), } } - pub fn _log(&self, level: log::Level, args: PrinterArgs, log_fn: LogFn) { + pub fn _log(&self, _level: log::Level, args: PrinterArgs, log_fn: LogFn) { let inner = &self.inner; let data = inner.data.var(); let offset = data.atomic_fetch_add(1, 1 + args.count as u32); @@ -122,7 +130,8 @@ impl Printer { let item_id = items.len() as u32; if_!( - offset.cmplt(data.len().uint()) & (offset + 1 + args.count as u32).cmple(data.len().uint()), + offset.cmplt(data.len().uint()) + & (offset + 1 + args.count as u32).cmple(data.len().uint()), { data.atomic_fetch_add(0, 1); data.write(offset, item_id); @@ -192,7 +201,7 @@ impl<'a> Scope<'a> { let items = data.items.read(); let mut i = 2; let item_count = host_data[0] as usize; - for j in 0..item_count { + for _j in 0..item_count { if i >= host_data.len() { break; } diff --git a/luisa_compute/src/resource.rs b/luisa_compute/src/resource.rs index 52fbb40..63848be 100644 --- a/luisa_compute/src/resource.rs +++ b/luisa_compute/src/resource.rs @@ -1,17 +1,19 @@ -use crate::macros::lc_assert; -use crate::*; -use api::BufferDownloadCommand; -use api::BufferUploadCommand; -use api::INVALID_RESOURCE_HANDLE; -use lang::Value; -use libc::c_void; -use runtime::*; -use std::cell::Cell; -use std::cell::RefCell; +use std::cell::{Cell, RefCell}; use std::ops::RangeBounds; use std::process::abort; - use std::sync::Arc; + +use parking_lot::lock_api::RawMutex as RawMutexTrait; +use parking_lot::RawMutex; + +use crate::internal_prelude::*; + +use crate::lang::index::IntoIndex; +use crate::runtime::*; + +use api::{BufferDownloadCommand, BufferUploadCommand, INVALID_RESOURCE_HANDLE}; +use libc::c_void; + pub struct ByteBuffer { pub(crate) device: Device, pub(crate) handle: Arc, @@ -110,7 +112,7 @@ impl<'a> ByteBufferView<'a> { size: data.len(), data: data.as_mut_ptr() as *mut u8, }), - marker: std::marker::PhantomData, + marker: PhantomData, resource_tracker: rt, callback: None, } @@ -141,7 +143,7 @@ impl<'a> ByteBufferView<'a> { size: data.len(), data: data.as_ptr() as *const u8, }), - marker: std::marker::PhantomData, + marker: PhantomData, resource_tracker: rt, callback: None, } @@ -168,7 +170,7 @@ impl<'a> ByteBufferView<'a> { dst_offset: dst.offset, size: self.len, }), - marker: std::marker::PhantomData, + marker: PhantomData, resource_tracker: rt, callback: None, } @@ -242,7 +244,7 @@ pub struct Buffer { pub(crate) device: Device, pub(crate) handle: Arc, pub(crate) len: usize, - pub(crate) _marker: std::marker::PhantomData, + pub(crate) _marker: PhantomData, } pub(crate) struct BufferHandle { pub(crate) device: Device, @@ -281,7 +283,7 @@ impl<'a, T: Value> BufferView<'a, T> { size: data.len() * std::mem::size_of::(), data: data.as_mut_ptr() as *mut u8, }), - marker: std::marker::PhantomData, + marker: PhantomData, resource_tracker: rt, callback: None, } @@ -312,7 +314,7 @@ impl<'a, T: Value> BufferView<'a, T> { size: data.len() * std::mem::size_of::(), data: data.as_ptr() as *const u8, }), - marker: std::marker::PhantomData, + marker: PhantomData, resource_tracker: rt, callback: None, } @@ -339,7 +341,7 @@ impl<'a, T: Value> BufferView<'a, T> { dst_offset: dst.offset * std::mem::size_of::(), size: self.len * std::mem::size_of::(), }), - marker: std::marker::PhantomData, + marker: PhantomData, resource_tracker: rt, callback: None, } @@ -359,7 +361,7 @@ impl Buffer { device: self.device.clone(), handle: self.handle.clone(), len: self.len, - _marker: std::marker::PhantomData, + _marker: PhantomData, } } #[inline] @@ -461,18 +463,18 @@ pub(crate) struct BindlessArraySlot { } pub struct BufferHeap { pub(crate) inner: BindlessArray, - pub(crate) _marker: std::marker::PhantomData, + pub(crate) _marker: PhantomData, } pub struct BufferHeapVar { inner: BindlessArrayVar, - _marker: std::marker::PhantomData, + _marker: PhantomData, } impl BufferHeap { #[inline] pub fn var(&self) -> BufferHeapVar { BufferHeapVar { inner: self.inner.var(), - _marker: std::marker::PhantomData, + _marker: PhantomData, } } #[inline] @@ -785,7 +787,7 @@ impl BindlessArray { modifications: modifications.as_ptr(), modifications_count: modifications.len(), }), - marker: std::marker::PhantomData, + marker: PhantomData, resource_tracker: rt, callback: Some(Box::new(move || unsafe { lock.unlock(); @@ -912,7 +914,7 @@ impl_io_texel!( f32, Float4, |x: Float4Expr| x.xy(), - |x: Float2Expr| { make_float4(x.x(), x.y(), 0.0, 0.0) } + |x: Float2Expr| { Float4::expr(x.x(), x.y(), 0.0, 0.0) } ); impl_io_texel!(Float4, f32, Float4, |x: Float4Expr| x, |x: Float4Expr| x); @@ -927,10 +929,10 @@ impl_io_texel!(u32, u32, Uint4, |x: Uint4Expr| x.x(), |x| Uint4Expr::splat( )); impl_io_texel!(i32, i32, Int4, |x: Int4Expr| x.x(), |x| Int4Expr::splat(x)); impl_io_texel!(Uint2, u32, Uint4, |x: Uint4Expr| x.xy(), |x: Uint2Expr| { - make_uint4(x.x(), x.y(), 0u32, 0u32) + Uint4::expr(x.x(), x.y(), 0u32, 0u32) }); impl_io_texel!(Int2, i32, Int4, |x: Int4Expr| x.xy(), |x: Int2Expr| { - make_int4(x.x(), x.y(), 0i32, 0i32) + Int4::expr(x.x(), x.y(), 0i32, 0i32) }); impl_io_texel!(Uint4, u32, Uint4, |x: Uint4Expr| x, |x| x); impl_io_texel!(Int4, i32, Int4, |x: Int4Expr| x, |x| x); @@ -1004,20 +1006,25 @@ impl_storage_texel!([f16; 4], Byte4, f32, Float2, Float4, Int2, Int4, Uint2, Uin // `T` is the read out type of the texture, which is not necessarily the same as the storage type // In fact, the texture can be stored in any format as long as it can be converted to `T` pub struct Tex2d { + #[allow(dead_code)] pub(crate) width: u32, + #[allow(dead_code)] pub(crate) height: u32, pub(crate) handle: Arc, - pub(crate) marker: std::marker::PhantomData, + pub(crate) marker: PhantomData, } // `T` is the read out type of the texture, which is not necessarily the same as the storage type // In fact, the texture can be stored in any format as long as it can be converted to `T` pub struct Tex3d { + #[allow(dead_code)] pub(crate) width: u32, + #[allow(dead_code)] pub(crate) height: u32, + #[allow(dead_code)] pub(crate) depth: u32, pub(crate) handle: Arc, - pub(crate) marker: std::marker::PhantomData, + pub(crate) marker: PhantomData, } #[derive(Clone, Copy)] pub struct Tex2dView<'a, T: IoTexel> { @@ -1062,7 +1069,7 @@ macro_rules! impl_tex_view { data: data.as_mut_ptr() as *mut u8, }), resource_tracker: rt, - marker: std::marker::PhantomData, + marker: PhantomData, callback: None, } } @@ -1093,7 +1100,7 @@ macro_rules! impl_tex_view { data: data.as_ptr() as *const u8, }), resource_tracker: rt, - marker: std::marker::PhantomData, + marker: PhantomData, callback: None, } } @@ -1122,7 +1129,7 @@ macro_rules! impl_tex_view { buffer_offset: buffer_view.offset, }), resource_tracker: rt, - marker: std::marker::PhantomData, + marker: PhantomData, callback: None, } } @@ -1154,7 +1161,7 @@ macro_rules! impl_tex_view { buffer_offset: buffer_view.offset, }), resource_tracker: rt, - marker: std::marker::PhantomData, + marker: PhantomData, callback: None, } } @@ -1184,7 +1191,7 @@ macro_rules! impl_tex_view { dst_level: other.level, }), resource_tracker: rt, - marker: std::marker::PhantomData, + marker: PhantomData, callback: None, } } @@ -1282,7 +1289,7 @@ impl Tex3d { } #[derive(Clone)] pub struct BufferVar { - pub(crate) marker: std::marker::PhantomData, + pub(crate) marker: PhantomData, #[allow(dead_code)] pub(crate) handle: Option>, pub(crate) node: NodeRef, @@ -1305,7 +1312,7 @@ pub struct BindlessArrayVar { pub struct BindlessBufferVar { array: NodeRef, buffer_index: Expr, - _marker: std::marker::PhantomData, + _marker: PhantomData, } impl ToNode for BindlessBufferVar { fn node(&self) -> NodeRef { @@ -1332,7 +1339,7 @@ impl IndexRead for BindlessBufferVar { } impl BindlessBufferVar { pub fn len(&self) -> Expr { - let stride = const_(T::type_().size() as u64); + let stride = (T::type_().size() as u64).expr(); Expr::::from_node(__current_scope(|b| { b.call( Func::BindlessBufferSize, @@ -1373,7 +1380,7 @@ impl BindlessByteBufferVar { })) } pub fn len(&self) -> Expr { - let s = const_(1u64); + let s = (1u64).expr(); Expr::::from_node(__current_scope(|b| { b.call( Func::BindlessBufferSize, @@ -1573,10 +1580,7 @@ impl BindlessArrayVar { }; v } - pub fn byte_address_buffer( - &self, - buffer_index: impl Into>, - ) -> BindlessByteBufferVar { + pub fn byte_address_buffer(&self, buffer_index: impl Into>) -> BindlessByteBufferVar { let v = BindlessByteBufferVar { array: self.node, buffer_index: buffer_index.into(), @@ -1587,7 +1591,7 @@ impl BindlessArrayVar { let v = BindlessBufferVar { array: self.node, buffer_index: buffer_index.into(), - _marker: std::marker::PhantomData, + _marker: PhantomData, }; if __env_need_backtrace() && is_cpu_backend() { let vt = v.__type(); @@ -1718,7 +1722,7 @@ impl BufferVar { }); Self { node, - marker: std::marker::PhantomData, + marker: PhantomData, handle: Some(buffer.buffer.handle.clone()), } } @@ -1911,7 +1915,7 @@ pub struct Tex2dVar { pub(crate) node: NodeRef, #[allow(dead_code)] pub(crate) handle: Option>, - pub(crate) marker: std::marker::PhantomData, + pub(crate) marker: PhantomData, #[allow(dead_code)] pub(crate) level: Option, } @@ -1943,7 +1947,7 @@ impl Tex2dVar { node, handle: Some(view.tex.handle.clone()), level: Some(view.level), - marker: std::marker::PhantomData, + marker: PhantomData, } } pub fn read(&self, uv: impl Into>) -> Expr { @@ -1997,7 +2001,7 @@ impl Tex3dVar { node, handle: Some(view.tex.handle.clone()), level: Some(view.level), - marker: std::marker::PhantomData, + marker: PhantomData, } } pub fn read(&self, uv: impl Into>) -> Expr { @@ -2028,7 +2032,7 @@ pub struct Tex3dVar { pub(crate) node: NodeRef, #[allow(dead_code)] pub(crate) handle: Option>, - pub(crate) marker: std::marker::PhantomData, + pub(crate) marker: PhantomData, #[allow(dead_code)] pub(crate) level: Option, } diff --git a/luisa_compute/src/rtx.rs b/luisa_compute/src/rtx.rs index 02079c6..cda8929 100644 --- a/luisa_compute/src/rtx.rs +++ b/luisa_compute/src/rtx.rs @@ -1,12 +1,23 @@ -use std::{collections::HashMap, marker::PhantomData, sync::Arc}; +use std::collections::HashMap; +use std::marker::PhantomData; +use std::sync::Arc; -use crate::{runtime::submit_default_stream_and_sync, ResourceTracker, *}; -use api::AccelBuildRequest; -use luisa_compute_api_types as api; -use luisa_compute_derive::__Value; -use luisa_compute_ir::ir::{new_node, AccelBinding, Binding, Func, Instruction, IrBuilder, Node}; +use crate::internal_prelude::*; + +use crate::runtime::*; +use crate::{ResourceTracker, *}; +use luisa_compute_ir::ir::{ + new_node, AccelBinding, Binding, Func, Instruction, IrBuilder, Node, NodeRef, Type, +}; use parking_lot::RwLock; use std::ops::Deref; + +pub use api::{ + AccelBuildModificationFlags, AccelBuildRequest, AccelOption, AccelUsageHint, MeshType, + PixelFormat, PixelStorage, +}; +use luisa_compute_api_types as api; + pub(crate) struct AccelHandle { pub(crate) device: Device, pub(crate) handle: api::Accel, @@ -41,6 +52,7 @@ pub(crate) struct ProceduralPrimitiveHandle { pub(crate) device: Device, pub(crate) handle: api::ProceduralPrimitive, pub(crate) native_handle: *mut std::ffi::c_void, + #[allow(dead_code)] pub(crate) aabb_buffer: Arc, } impl Drop for ProceduralPrimitiveHandle { @@ -85,7 +97,9 @@ pub(crate) struct MeshHandle { pub(crate) device: Device, pub(crate) handle: api::Mesh, pub(crate) native_handle: *mut std::ffi::c_void, + #[allow(dead_code)] pub(crate) vbuffer: Arc, + #[allow(dead_code)] pub(crate) ibuffer: Arc, } impl Drop for MeshHandle { @@ -148,7 +162,8 @@ impl Accel { let mut flags = api::AccelBuildModificationFlags::PRIMITIVE | AccelBuildModificationFlags::TRANSFORM; - flags |= api::AccelBuildModificationFlags::VISIBILITY | api::AccelBuildModificationFlags::USER_ID; + flags |= api::AccelBuildModificationFlags::VISIBILITY + | api::AccelBuildModificationFlags::USER_ID; if opaque { flags |= api::AccelBuildModificationFlags::OPAQUE_ON; @@ -183,7 +198,8 @@ impl Accel { ) { let mut flags = api::AccelBuildModificationFlags::PRIMITIVE; dbg!(flags); - flags |= api::AccelBuildModificationFlags::VISIBILITY | api::AccelBuildModificationFlags::USER_ID; + flags |= api::AccelBuildModificationFlags::VISIBILITY + | api::AccelBuildModificationFlags::USER_ID; if opaque { flags |= api::AccelBuildModificationFlags::OPAQUE_ON; @@ -312,7 +328,7 @@ pub struct AccelVar { #[repr(C)] #[repr(align(16))] -#[derive(Clone, Copy, __Value, Debug)] +#[derive(Clone, Copy, Value, Debug)] pub struct Ray { pub orig: PackedFloat3, pub tmin: f32, @@ -320,14 +336,14 @@ pub struct Ray { pub tmax: f32, } #[repr(C)] -#[derive(Clone, Copy, __Value, Debug)] +#[derive(Clone, Copy, Value, Debug)] pub struct Aabb { pub min: PackedFloat3, pub max: PackedFloat3, } #[repr(C)] -#[derive(Clone, Copy, __Value, Debug)] +#[derive(Clone, Copy, Value, Debug)] pub struct TriangleHit { pub inst: u32, pub prim: u32, @@ -336,14 +352,14 @@ pub struct TriangleHit { } #[repr(C)] -#[derive(Clone, Copy, __Value, Debug)] +#[derive(Clone, Copy, Value, Debug)] pub struct ProceduralHit { pub inst: u32, pub prim: u32, } #[repr(C)] -#[derive(Clone, Copy, __Value, Debug)] +#[derive(Clone, Copy, Value, Debug)] pub struct CommittedHit { pub inst_id: u32, pub prim_id: u32, @@ -372,8 +388,8 @@ pub enum HitType { pub fn offset_ray_origin(p: Expr, n: Expr) -> Expr { lazy_static! { - static ref F: Callable, Expr)-> Expr> = - create_static_callable::, Expr)->Expr>(|p, n| { + static ref F: Callable, Expr) -> Expr> = + create_static_callable::, Expr) -> Expr>(|p, n| { const ORIGIN: f32 = 1.0f32 / 32.0f32; const FLOAT_SCALE: f32 = 1.0f32 / 65536.0f32; const INT_SCALE: f32 = 256.0f32; @@ -392,7 +408,7 @@ pub type Index = PackedUint3; #[repr(C)] #[repr(align(8))] -#[derive(Clone, Copy, __Value, Debug)] +#[derive(Clone, Copy, Value, Debug)] pub struct Hit { pub inst_id: u32, pub prim_id: u32, diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index 5f12324..90dc02b 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -1,28 +1,38 @@ -use crate::backend::Backend; -use crate::lang::KernelBuildOptions; -use crate::rtx::ProceduralPrimitiveHandle; -use crate::*; -use crate::{lang::Value, resource::*}; - -use api::AccelOption; -use lang::{KernelBuildFn, KernelBuilder, KernelParameter, KernelSignature}; -pub use luisa_compute_api_types as api; -use luisa_compute_backend::proxy::ProxyBackend; -use luisa_compute_ir::ir::{self, KernelModule}; -use luisa_compute_ir::CArc; -use parking_lot::lock_api::RawMutex as RawMutexTrait; -use parking_lot::{Condvar, Mutex, RawMutex, RwLock}; -use raw_window_handle::HasRawWindowHandle; -use rtx::{Accel, Mesh, MeshHandle}; +use std::any::Any; use std::cell::{Cell, RefCell}; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; +use std::env; use std::ffi::CString; use std::hash::Hash; use std::ops::Deref; use std::path::PathBuf; -use std::sync::Arc; +use std::rc::Rc; +use std::sync::{Arc, Weak}; + +use parking_lot::lock_api::RawMutex as RawMutexTrait; +use parking_lot::{Condvar, Mutex, RawMutex, RwLock}; + +use raw_window_handle::HasRawWindowHandle; use winit::window::Window; +use crate::internal_prelude::*; +use ir::{ + CallableModule, CallableModuleRef, Capture, CpuCustomOp, KernelModule, Module, ModuleFlags, + ModuleKind, ModulePools, +}; + +use crate::backend::Backend; +use crate::rtx; +use crate::rtx::{Accel, Mesh, MeshHandle, ProceduralPrimitiveHandle}; + +use api::AccelOption; +pub use luisa_compute_api_types as api; +use luisa_compute_backend::proxy::ProxyBackend; + +mod kernel; + +pub use kernel::*; + #[derive(Clone)] pub struct Device { pub(crate) inner: Arc, @@ -60,7 +70,7 @@ pub(crate) struct DeviceHandle { pub(crate) backend: ProxyBackend, pub(crate) default_stream: Option>, #[allow(dead_code)] - pub(crate) ctx: Arc, + pub(crate) ctx: Arc, } unsafe impl Send for DeviceHandle {} @@ -183,7 +193,7 @@ impl Device { handle: api::Buffer(buffer.resource.handle), native_handle: buffer.resource.native_handle, }), - _marker: std::marker::PhantomData {}, + _marker: PhantomData {}, len: count, }; buffer @@ -206,7 +216,7 @@ impl Device { let array = self.create_bindless_array(slots); BufferHeap { inner: array, - _marker: std::marker::PhantomData {}, + _marker: PhantomData {}, } } pub fn create_bindless_array(&self, slots: usize) -> BindlessArray { @@ -259,7 +269,7 @@ impl Device { width, height, handle, - marker: std::marker::PhantomData {}, + marker: PhantomData {}, }; tex } @@ -291,7 +301,7 @@ impl Device { height, depth, handle, - marker: std::marker::PhantomData {}, + marker: PhantomData {}, }; tex } @@ -648,7 +658,7 @@ impl Drop for StreamHandle { pub struct Scope<'a> { handle: Arc, - marker: std::marker::PhantomData<&'a ()>, + marker: PhantomData<&'a ()>, synchronized: Cell, resource_tracker: RefCell, } @@ -674,7 +684,7 @@ impl<'a> Scope<'a> { #[inline] fn command_list(&self) -> CommandList<'a> { CommandList::<'a> { - marker: std::marker::PhantomData {}, + marker: PhantomData {}, commands: Vec::new(), } } @@ -820,7 +830,7 @@ impl Stream { self.handle.lock(); Scope { handle: self.handle.clone(), - marker: std::marker::PhantomData {}, + marker: PhantomData {}, synchronized: Cell::new(false), resource_tracker: RefCell::new(ResourceTracker::new()), } @@ -836,7 +846,7 @@ impl Stream { } pub(crate) struct CommandList<'a> { - marker: std::marker::PhantomData<&'a ()>, + marker: PhantomData<&'a ()>, commands: Vec>, } @@ -850,6 +860,7 @@ impl<'a> CommandList<'a> { pub fn extend>>(&mut self, commands: I) { self.commands.extend(commands); } + #[allow(dead_code)] pub fn push(&mut self, command: Command<'a>) { self.commands.push(command); } @@ -870,7 +881,7 @@ pub struct Command<'a> { #[allow(dead_code)] pub(crate) inner: api::Command, // is this really necessary? - pub(crate) marker: std::marker::PhantomData<&'a ()>, + pub(crate) marker: PhantomData<&'a ()>, pub(crate) callback: Option>, #[allow(dead_code)] pub(crate) resource_tracker: ResourceTracker, @@ -879,6 +890,7 @@ pub struct Command<'a> { pub(crate) struct AsyncShaderArtifact { shader: Option, // strange naming, huh? + #[allow(dead_code)] name: Arc, } @@ -1163,7 +1175,7 @@ impl RawKernel { args_count: args.len(), dispatch_size, }), - marker: std::marker::PhantomData, + marker: PhantomData, resource_tracker: rt, callback: None, } @@ -1176,7 +1188,7 @@ impl RawKernel { pub struct Callable> { #[allow(dead_code)] pub(crate) inner: RawCallable, - pub(crate) _marker: std::marker::PhantomData, + pub(crate) _marker: PhantomData, } pub(crate) struct DynCallableInner> { builder: Box, &mut KernelBuilder) -> Callable>, @@ -1217,8 +1229,8 @@ impl> DynCallable { { let callables = &mut inner.callables; for c in callables { - if lang::__check_callable(&c.inner.module, nodes) { - return CallableRet::_from_return(lang::__invoke_callable( + if crate::lang::__check_callable(&c.inner.module, nodes) { + return CallableRet::_from_return(crate::lang::__invoke_callable( &c.inner.module, nodes, )); @@ -1243,12 +1255,12 @@ impl> DynCallable { *r.borrow_mut() = r_backup; }); assert!( - lang::__check_callable(&new_callable.inner.module, nodes), + crate::lang::__check_callable(&new_callable.inner.module, nodes), "Callable builder returned a callable that does not match the arguments" ); let callables = &mut inner.callables; callables.push(new_callable); - CallableRet::_from_return(lang::__invoke_callable( + CallableRet::_from_return(crate::lang::__invoke_callable( &callables.last().unwrap().inner.module, nodes, )) @@ -1264,7 +1276,7 @@ pub struct RawCallable { pub struct Kernel> { pub(crate) inner: RawKernel, - pub(crate) _marker: std::marker::PhantomData, + pub(crate) _marker: PhantomData, } unsafe impl> Send for Kernel {} unsafe impl> Sync for Kernel {} @@ -1327,7 +1339,7 @@ macro_rules! impl_call_for_callable { $first.encode(&mut encoder); $($rest.encode(&mut encoder);)* CallableRet::_from_return( - lang::__invoke_callable(&self.inner.module, &encoder.args)) + crate::lang::__invoke_callable(&self.inner.module, &encoder.args)) } } impl DynCallableR> { @@ -1345,7 +1357,7 @@ macro_rules! impl_call_for_callable { impl CallableR> { pub fn call(&self)->R { CallableRet::_from_return( - lang::__invoke_callable(&self.inner.module, &[])) + crate::lang::__invoke_callable(&self.inner.module, &[])) } } impl DynCallableR> { diff --git a/luisa_compute/src/runtime/kernel.rs b/luisa_compute/src/runtime/kernel.rs new file mode 100644 index 0000000..1bee826 --- /dev/null +++ b/luisa_compute/src/runtime/kernel.rs @@ -0,0 +1,733 @@ +use super::*; + +#[macro_export] +macro_rules! impl_callable_param { + ($t:ty, $e:ty, $v:ty) => { + impl CallableParameter for $e { + fn def_param( + _: Option>, + builder: &mut KernelBuilder, + ) -> Self { + builder.value::<$t>() + } + fn encode(&self, encoder: &mut CallableArgEncoder) { + encoder.var(*self) + } + } + impl CallableParameter for $v { + fn def_param( + _: Option>, + builder: &mut KernelBuilder, + ) -> Self { + builder.var::<$t>() + } + fn encode(&self, encoder: &mut CallableArgEncoder) { + encoder.var(*self) + } + } + }; +} + +// Not recommended to use this directly +pub struct KernelBuilder { + device: Option, + args: Vec, +} + +pub trait CallableParameter: Sized + Clone + 'static { + fn def_param(arg: Option>, builder: &mut KernelBuilder) -> Self; + fn encode(&self, encoder: &mut CallableArgEncoder); +} +macro_rules! impl_callable_parameter_for_tuple { + ()=>{ + impl CallableParameter for () { + fn def_param(_: Option>, _: &mut KernelBuilder) {} + fn encode(&self, _: &mut CallableArgEncoder) { } + } + }; + ($first:ident $($rest:ident) *) => { + impl<$first:CallableParameter, $($rest: CallableParameter),*> CallableParameter for ($first, $($rest,)*) { + #[allow(non_snake_case)] + fn def_param(arg: Option>, builder: &mut KernelBuilder) -> Self { + if let Some(arg) = arg { + let ($first, $($rest,)*) = arg.downcast_ref::<($first, $($rest,)*)>().cloned().unwrap(); + let $first = $first::def_param(Some(std::rc::Rc::new($first)), builder); + let ($($rest,)*) = ($($rest::def_param(Some(std::rc::Rc::new($rest)), builder),)*); + ($first, $($rest,)*) + }else { + let $first = $first::def_param(None, builder); + let ($($rest,)*) = ($($rest::def_param(None, builder),)*); + ($first, $($rest,)*) + } + } + #[allow(non_snake_case)] + fn encode(&self, encoder: &mut CallableArgEncoder) { + let ($first, $($rest,)*) = self; + $first.encode(encoder); + $($rest.encode(encoder);)* + } + } + impl_callable_parameter_for_tuple!($($rest)*); + }; + +} +impl_callable_parameter_for_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); + +impl CallableParameter for BufferVar { + fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { + builder.buffer() + } + fn encode(&self, encoder: &mut CallableArgEncoder) { + encoder.buffer(self) + } +} +impl CallableParameter for ByteBufferVar { + fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { + builder.byte_buffer() + } + fn encode(&self, encoder: &mut CallableArgEncoder) { + encoder.byte_buffer(self) + } +} +impl CallableParameter for Tex2dVar { + fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { + builder.tex2d() + } + fn encode(&self, encoder: &mut CallableArgEncoder) { + encoder.tex2d(self) + } +} + +impl CallableParameter for Tex3dVar { + fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { + builder.tex3d() + } + fn encode(&self, encoder: &mut CallableArgEncoder) { + encoder.tex3d(self) + } +} + +impl CallableParameter for BindlessArrayVar { + fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { + builder.bindless_array() + } + fn encode(&self, encoder: &mut CallableArgEncoder) { + encoder.bindless_array(self) + } +} + +impl CallableParameter for rtx::AccelVar { + fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { + builder.accel() + } + fn encode(&self, encoder: &mut CallableArgEncoder) { + encoder.accel(self) + } +} + +pub trait KernelParameter { + fn def_param(builder: &mut KernelBuilder) -> Self; +} + +impl KernelParameter for U +where + U: ExprProxy, + T: Value, +{ + fn def_param(builder: &mut KernelBuilder) -> Self { + builder.uniform::() + } +} +impl KernelParameter for ByteBufferVar { + fn def_param(builder: &mut KernelBuilder) -> Self { + builder.byte_buffer() + } +} +impl KernelParameter for BufferVar { + fn def_param(builder: &mut KernelBuilder) -> Self { + builder.buffer() + } +} + +impl KernelParameter for Tex2dVar { + fn def_param(builder: &mut KernelBuilder) -> Self { + builder.tex2d() + } +} + +impl KernelParameter for Tex3dVar { + fn def_param(builder: &mut KernelBuilder) -> Self { + builder.tex3d() + } +} + +impl KernelParameter for BindlessArrayVar { + fn def_param(builder: &mut KernelBuilder) -> Self { + builder.bindless_array() + } +} + +impl KernelParameter for rtx::AccelVar { + fn def_param(builder: &mut KernelBuilder) -> Self { + builder.accel() + } +} +macro_rules! impl_kernel_param_for_tuple { + ($first:ident $($rest:ident)*) => { + impl<$first:KernelParameter, $($rest: KernelParameter),*> KernelParameter for ($first, $($rest,)*) { + #[allow(non_snake_case)] + fn def_param(builder: &mut KernelBuilder) -> Self { + ($first::def_param(builder), $($rest::def_param(builder)),*) + } + } + impl_kernel_param_for_tuple!($($rest)*); + }; + ()=>{ + impl KernelParameter for () { + fn def_param(_: &mut KernelBuilder) -> Self { + } + } + } +} +impl_kernel_param_for_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); +impl KernelBuilder { + pub fn new(device: Option, is_kernel: bool) -> Self { + RECORDER.with(|r| { + let mut r = r.borrow_mut(); + assert!(!r.lock, "Cannot record multiple kernels at the same time"); + assert!( + r.scopes.is_empty(), + "Cannot record multiple kernels at the same time" + ); + r.lock = true; + r.device = device.as_ref().map(|d| WeakDevice::new(d)); + r.pools = Some(CArc::new(ModulePools::new())); + r.scopes.clear(); + r.building_kernel = is_kernel; + let pools = r.pools.clone().unwrap(); + r.scopes.push(IrBuilder::new(pools)); + }); + Self { + device, + args: vec![], + } + } + pub(crate) fn arg(&mut self, ty: CArc, by_value: bool) -> NodeRef { + let node = new_node( + __module_pools(), + Node::new(CArc::new(Instruction::Argument { by_value }), ty), + ); + self.args.push(node); + node + } + pub fn value(&mut self) -> Expr { + let node = self.arg(T::type_(), true); + FromNode::from_node(node) + } + pub fn var(&mut self) -> Var { + let node = self.arg(T::type_(), false); + FromNode::from_node(node) + } + pub fn uniform(&mut self) -> Expr { + let node = new_node( + __module_pools(), + Node::new(CArc::new(Instruction::Uniform), T::type_()), + ); + self.args.push(node); + FromNode::from_node(node) + } + pub fn byte_buffer(&mut self) -> ByteBufferVar { + let node = new_node( + __module_pools(), + Node::new(CArc::new(Instruction::Buffer), Type::void()), + ); + self.args.push(node); + ByteBufferVar { node, handle: None } + } + pub fn buffer(&mut self) -> BufferVar { + let node = new_node( + __module_pools(), + Node::new(CArc::new(Instruction::Buffer), T::type_()), + ); + self.args.push(node); + BufferVar { + node, + marker: PhantomData, + handle: None, + } + } + pub fn tex2d(&mut self) -> Tex2dVar { + let node = new_node( + __module_pools(), + Node::new(CArc::new(Instruction::Texture2D), T::type_()), + ); + self.args.push(node); + Tex2dVar { + node, + marker: PhantomData, + handle: None, + level: None, + } + } + pub fn tex3d(&mut self) -> Tex3dVar { + let node = new_node( + __module_pools(), + Node::new(CArc::new(Instruction::Texture3D), T::type_()), + ); + self.args.push(node); + Tex3dVar { + node, + marker: PhantomData, + handle: None, + level: None, + } + } + pub fn bindless_array(&mut self) -> BindlessArrayVar { + let node = new_node( + __module_pools(), + Node::new(CArc::new(Instruction::Bindless), Type::void()), + ); + self.args.push(node); + BindlessArrayVar { node, handle: None } + } + pub fn accel(&mut self) -> rtx::AccelVar { + let node = new_node( + __module_pools(), + Node::new(CArc::new(Instruction::Accel), Type::void()), + ); + self.args.push(node); + rtx::AccelVar { node, handle: None } + } + fn collect_module_info(&self) -> (ResourceTracker, Vec>, Vec) { + RECORDER.with(|r| { + let mut resource_tracker = ResourceTracker::new(); + let r = r.borrow_mut(); + let mut captured: Vec = Vec::new(); + let mut captured_buffers: Vec<_> = r.captured_buffer.values().cloned().collect(); + captured_buffers.sort_by_key(|(i, _, _, _)| *i); + for (j, (i, node, binding, handle)) in captured_buffers.into_iter().enumerate() { + assert_eq!(j, i); + captured.push(Capture { node, binding }); + resource_tracker.add_any(handle); + } + let mut cpu_custom_ops: Vec<_> = r.cpu_custom_ops.values().cloned().collect(); + cpu_custom_ops.sort_by_key(|(i, _)| *i); + let mut cpu_custom_ops: Vec> = cpu_custom_ops + .iter() + .enumerate() + .map(|(j, (i, op))| { + assert_eq!(j, *i); + (*op).clone() + }) + .collect::>(); + let callables: Vec = r.callables.values().cloned().collect(); + let mut captured_set = HashSet::::new(); + let mut cpu_custom_ops_set = HashSet::::new(); + let mut callable_set = HashSet::::new(); + for capture in captured.iter() { + captured_set.insert(*capture); + } + for op in &cpu_custom_ops { + cpu_custom_ops_set.insert(CArc::as_ptr(op) as u64); + } + for c in &callables { + callable_set.insert(CArc::as_ptr(&c.0) as u64); + for capture in c.0.captures.as_ref() { + if !captured_set.contains(capture) { + captured_set.insert(*capture); + captured.push(*capture); + } + } + for op in c.0.cpu_custom_ops.as_ref() { + let id = CArc::as_ptr(op) as u64; + if !cpu_custom_ops_set.contains(&id) { + cpu_custom_ops_set.insert(id); + cpu_custom_ops.push(op.clone()); + } + } + } + (resource_tracker, cpu_custom_ops, captured) + }) + } + fn build_callable(&mut self, body: impl FnOnce(&mut Self) -> R) -> RawCallable { + let ret = body(self); + let ret_type = ret._return(); + let (rt, cpu_custom_ops, captures) = self.collect_module_info(); + RECORDER.with(|r| { + let mut r = r.borrow_mut(); + assert!(r.lock); + r.lock = false; + if let Some(t) = &r.callable_ret_type { + assert!( + luisa_compute_ir::context::is_type_equal(t, &ret_type), + "Return type mismatch" + ); + } else { + r.callable_ret_type = Some(ret_type.clone()); + } + assert_eq!(r.scopes.len(), 1); + let scope = r.scopes.pop().unwrap(); + let entry = scope.finish(); + let ir_module = Module { + entry, + kind: ModuleKind::Kernel, + pools: r.pools.clone().unwrap(), + flags: ModuleFlags::REQUIRES_REV_AD_TRANSFORM + | ModuleFlags::REQUIRES_FWD_AD_TRANSFORM, + }; + let ir_module = luisa_compute_ir::transform::luisa_compute_ir_transform_auto(ir_module); + let module = CallableModule { + module: ir_module, + ret_type, + cpu_custom_ops: CBoxedSlice::new(cpu_custom_ops), + captures: CBoxedSlice::new(captures), + args: CBoxedSlice::new(self.args.clone()), + pools: r.pools.clone().unwrap(), + }; + let module = CallableModuleRef(CArc::new(module)); + r.reset(); + RawCallable { + module, + resource_tracker: rt, + } + }) + } + fn build_kernel( + &mut self, + options: KernelBuildOptions, + body: impl FnOnce(&mut Self), + ) -> crate::runtime::RawKernel { + body(self); + let (rt, cpu_custom_ops, captures) = self.collect_module_info(); + RECORDER.with(|r| -> crate::runtime::RawKernel { + let mut r = r.borrow_mut(); + assert!(r.lock); + r.lock = false; + assert_eq!(r.scopes.len(), 1); + let scope = r.scopes.pop().unwrap(); + let entry = scope.finish(); + + let ir_module = Module { + entry, + kind: ModuleKind::Kernel, + pools: r.pools.clone().unwrap(), + flags: ModuleFlags::REQUIRES_REV_AD_TRANSFORM + | ModuleFlags::REQUIRES_FWD_AD_TRANSFORM, + }; + let ir_module = luisa_compute_ir::transform::luisa_compute_ir_transform_auto(ir_module); + let module = KernelModule { + module: ir_module, + cpu_custom_ops: CBoxedSlice::new(cpu_custom_ops), + captures: CBoxedSlice::new(captures), + shared: CBoxedSlice::new(r.shared.clone()), + args: CBoxedSlice::new(self.args.clone()), + block_size: r.block_size.unwrap_or([64, 1, 1]), + pools: r.pools.clone().unwrap(), + }; + + let module = CArc::new(module); + let name = options.name.unwrap_or("".to_string()); + let name = Arc::new(CString::new(name).unwrap()); + let shader_options = api::ShaderOption { + enable_cache: options.enable_cache, + enable_fast_math: options.enable_fast_math, + enable_debug_info: options.enable_debug_info, + compile_only: false, + name: name.as_ptr(), + }; + let artifact = if options.async_compile { + ShaderArtifact::Async(AsyncShaderArtifact::new( + self.device.clone().unwrap(), + module.clone(), + shader_options, + name, + )) + } else { + ShaderArtifact::Sync( + self.device + .as_ref() + .unwrap() + .inner + .create_shader(&module, &shader_options), + ) + }; + // + r.reset(); + RawKernel { + artifact, + device: self.device.clone().unwrap(), + resource_tracker: rt, + module, + } + }) + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct KernelBuildOptions { + pub enable_debug_info: bool, + pub enable_optimization: bool, + pub async_compile: bool, + pub enable_cache: bool, + pub enable_fast_math: bool, + pub name: Option, +} + +impl Default for KernelBuildOptions { + fn default() -> Self { + let enable_debug_info = match env::var("LUISA_DEBUG") { + Ok(s) => s == "1", + Err(_) => false, + }; + Self { + enable_debug_info, + enable_optimization: true, + async_compile: false, + enable_cache: true, + enable_fast_math: true, + name: None, + } + } +} + +pub trait KernelBuildFn { + fn build_kernel( + &self, + builder: &mut KernelBuilder, + options: KernelBuildOptions, + ) -> crate::runtime::RawKernel; +} + +pub trait CallableBuildFn { + fn build_callable(&self, args: Option>, builder: &mut KernelBuilder) + -> RawCallable; +} + +pub trait StaticCallableBuildFn: CallableBuildFn {} + +// @FIXME: this looks redundant +pub unsafe trait CallableRet { + fn _return(&self) -> CArc; + fn _from_return(node: NodeRef) -> Self; +} + +unsafe impl CallableRet for () { + fn _return(&self) -> CArc { + Type::void() + } + fn _from_return(_: NodeRef) -> Self {} +} + +unsafe impl CallableRet for T { + fn _return(&self) -> CArc { + __current_scope(|b| { + b.return_(self.node()); + }); + T::Value::type_() + } + fn _from_return(node: NodeRef) -> Self { + Self::from_node(node) + } +} + +pub trait CallableSignature<'a> { + type Callable; + type DynCallable; + type Fn: CallableBuildFn; + type StaticFn: StaticCallableBuildFn; + type DynFn: CallableBuildFn + 'static; + type Ret: CallableRet; + fn wrap_raw_callable(callable: RawCallable) -> Self::Callable; + fn create_dyn_callable(device: Device, init_once: bool, f: Self::DynFn) -> Self::DynCallable; +} + +pub trait KernelSignature<'a> { + type Fn: KernelBuildFn; + type Kernel; + + fn wrap_raw_kernel(kernel: crate::runtime::RawKernel) -> Self::Kernel; +} +macro_rules! impl_callable_signature { + ()=>{ + impl<'a, R: CallableRet +'static> CallableSignature<'a> for fn()->R { + type Fn = &'a dyn Fn() ->R; + type DynFn = BoxR>; + type StaticFn = fn() -> R; + type Callable = CallableR>; + type DynCallable = DynCallableR>; + type Ret = R; + fn wrap_raw_callable(callable: RawCallable) -> Self::Callable{ + Callable { + inner: callable, + _marker:PhantomData, + } + } + fn create_dyn_callable(device:Device, init_once:bool, f: Self::DynFn) -> Self::DynCallable { + DynCallable::new(device, init_once, Box::new(move |arg, builder| { + let raw_callable = CallableBuildFn::build_callable(&f, Some(arg), builder); + Self::wrap_raw_callable(raw_callable) + })) + } + } + }; + ($first:ident $($rest:ident)*) => { + impl<'a, R:CallableRet +'static, $first:CallableParameter +'static, $($rest: CallableParameter +'static),*> CallableSignature<'a> for fn($first, $($rest,)*)->R { + type Fn = &'a dyn Fn($first, $($rest),*)->R; + type DynFn = BoxR>; + type Callable = CallableR>; + type StaticFn = fn($first, $($rest,)*)->R; + type DynCallable = DynCallableR>; + type Ret = R; + fn wrap_raw_callable(callable: RawCallable) -> Self::Callable{ + Callable { + inner: callable, + _marker:PhantomData, + } + } + fn create_dyn_callable(device:Device, init_once:bool, f: Self::DynFn) -> Self::DynCallable { + DynCallable::new(device, init_once, Box::new(move |arg, builder| { + let raw_callable = CallableBuildFn::build_callable(&f, Some(arg), builder); + Self::wrap_raw_callable(raw_callable) + })) + } + } + impl_callable_signature!($($rest)*); + }; +} +impl_callable_signature!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); +macro_rules! impl_kernel_signature { + ()=>{ + impl<'a> KernelSignature<'a> for fn() { + type Fn = &'a dyn Fn(); + type Kernel = Kernel; + fn wrap_raw_kernel(kernel: crate::runtime::RawKernel) -> Self::Kernel { + Self::Kernel{ + inner:kernel, + _marker:PhantomData, + } + } + } + }; + ($first:ident $($rest:ident)*) => { + impl<'a, $first:KernelArg +'static, $($rest: KernelArg +'static),*> KernelSignature<'a> for fn($first, $($rest,)*) { + type Fn = &'a dyn Fn($first::Parameter, $($rest::Parameter),*); + type Kernel = Kernel; + fn wrap_raw_kernel(kernel: crate::runtime::RawKernel) -> Self::Kernel { + Self::Kernel{ + inner:kernel, + _marker:PhantomData, + } + } + } + impl_kernel_signature!($($rest)*); + }; +} +impl_kernel_signature!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); + +macro_rules! impl_callable_build_for_fn { + ()=>{ + impl CallableBuildFn for &dyn Fn()->R { + fn build_callable(&self, _args: Option>, builder: &mut KernelBuilder)->RawCallable { + builder.build_callable( |_| { + self() + }) + } + } + impl CallableBuildFn for fn()->R { + fn build_callable(&self, _args: Option>, builder: &mut KernelBuilder)->RawCallable { + builder.build_callable( |_| { + self() + }) + } + } + impl CallableBuildFn for BoxR> { + fn build_callable(&self, _args: Option>, builder: &mut KernelBuilder)->RawCallable { + builder.build_callable( |_| { + self() + }) + } + } + impl StaticCallableBuildFn for fn()->R {} + }; + ($first:ident $($rest:ident)*) => { + impl CallableBuildFn for &dyn Fn($first, $($rest,)*)->R { + #[allow(non_snake_case)] + fn build_callable(&self, args: Option>, builder: &mut KernelBuilder)->RawCallable { + builder.build_callable( |builder| { + if let Some(args) = args { + let ($first, $($rest,)*) = args.downcast_ref::<($first, $($rest,)*)>().cloned().unwrap(); + let $first = $first::def_param(Some(Rc::new($first)), builder); + $(let $rest = $rest::def_param(Some(Rc::new($rest)), builder);)* + self($first, $($rest,)*) + } else { + let $first = $first::def_param(None, builder); + $(let $rest = $rest::def_param(None, builder);)* + self($first, $($rest,)*) + } + }) + } + } + impl CallableBuildFn for BoxR> { + #[allow(non_snake_case)] + fn build_callable(&self, args: Option>, builder: &mut KernelBuilder)->RawCallable { + builder.build_callable( |builder| { + if let Some(args) = args { + let ($first, $($rest,)*) = args.downcast_ref::<($first, $($rest,)*)>().cloned().unwrap(); + let $first = $first::def_param(Some(Rc::new($first)), builder); + $(let $rest = $rest::def_param(Some(Rc::new($rest)), builder);)* + self($first, $($rest,)*) + } else { + let $first = $first::def_param(None, builder); + $(let $rest = $rest::def_param(None, builder);)* + self($first, $($rest,)*) + } + }) + } + } + impl CallableBuildFn for fn($first, $($rest,)*)->R { + #[allow(non_snake_case)] + fn build_callable(&self, args: Option>, builder: &mut KernelBuilder)->RawCallable { + builder.build_callable( |builder| { + if let Some(args) = args { + let ($first, $($rest,)*) = args.downcast_ref::<($first, $($rest,)*)>().cloned().unwrap(); + let $first = $first::def_param(Some(Rc::new($first)), builder); + $(let $rest = $rest::def_param(Some(Rc::new($rest)), builder);)* + self($first, $($rest,)*) + } else { + let $first = $first::def_param(None, builder); + $(let $rest = $rest::def_param(None, builder);)* + self($first, $($rest,)*) + } + }) + } + } + impl StaticCallableBuildFn for fn($first, $($rest,)*)->R {} + impl_callable_build_for_fn!($($rest)*); + }; +} +impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); +macro_rules! impl_kernel_build_for_fn { + ()=>{ + impl KernelBuildFn for &dyn Fn() { + fn build_kernel(&self, builder: &mut KernelBuilder, options:KernelBuildOptions) -> crate::runtime::RawKernel { + builder.build_kernel(options, |_| { + self() + }) + } + } + }; + ($first:ident $($rest:ident)*) => { + impl<$first:KernelParameter, $($rest: KernelParameter),*> KernelBuildFn for &dyn Fn($first, $($rest,)*) { + #[allow(non_snake_case)] + fn build_kernel(&self, builder: &mut KernelBuilder, options:KernelBuildOptions) -> crate::runtime::RawKernel { + builder.build_kernel(options, |builder| { + let $first = $first::def_param(builder); + $(let $rest = $rest::def_param(builder);)* + self($first, $($rest,)*) + }) + } + } + impl_kernel_build_for_fn!($($rest)*); + }; +} +impl_kernel_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); diff --git a/luisa_compute/tests/autodiff.rs b/luisa_compute/tests/autodiff.rs index 510c35d..97722bf 100644 --- a/luisa_compute/tests/autodiff.rs +++ b/luisa_compute/tests/autodiff.rs @@ -1,13 +1,13 @@ use std::ops::Range; +use luisa::lang::diff::*; +use luisa::lang::types::core::*; +use luisa::lang::types::vector::*; use luisa::prelude::*; -use luisa::*; use luisa_compute as luisa; use rand::prelude::*; -use rayon::{ - prelude::{IntoParallelIterator, ParallelIterator}, - slice::ParallelSliceMut, -}; +use rayon::prelude::{IntoParallelIterator, ParallelIterator}; +use rayon::slice::ParallelSliceMut; #[path = "common.rs"] mod common; use common::*; @@ -227,7 +227,7 @@ struct Foo { } autodiff_2!(autodiff_const, 1.0..10.0, |x: Float, y: Float| { - let k = 2.0 / const_::(3.0); + let k = 2.0 / 3.0_f32.expr(); x * k + y * k }); autodiff_2!(autodiff_struct, 1.0..10.0, |x: Float, y: Float| { @@ -266,7 +266,7 @@ fn autodiff_vec3_reduce_add_manual() { let x = inputs[0]; let y = inputs[1]; let z = inputs[2]; - let v = make_float3(x, y, z); + let v = Float3::expr(x, y, z); v.x() + v.y() + v.z() }); } @@ -277,7 +277,7 @@ fn autodiff_vec3_reduce_prod_manual() { let x = inputs[0]; let y = inputs[1]; let z = inputs[2]; - let v = make_float3(x, y, z); + let v = Float3::expr(x, y, z); v.x() * v.y() * v.z() }); } @@ -287,7 +287,7 @@ fn autodiff_vec3_reduce_add() { let x = inputs[0]; let y = inputs[1]; let z = inputs[2]; - let v = make_float3(x, y, z); + let v = Float3::expr(x, y, z); v.reduce_sum() }); } @@ -297,7 +297,7 @@ fn autodiff_vec3_reduce_mul() { let x = inputs[0]; let y = inputs[1]; let z = inputs[2]; - let v = make_float3(x, y, z); + let v = Float3::expr(x, y, z); v.reduce_prod() }); } @@ -307,7 +307,7 @@ fn autodiff_vec3_dot() { let x = inputs[0]; let y = inputs[1]; let z = inputs[2]; - let v = make_float3(x, y, z); + let v = Float3::expr(x, y, z); v.dot(v) }); } @@ -317,7 +317,7 @@ fn autodiff_vec3_length() { let x = inputs[0]; let y = inputs[1]; let z = inputs[2]; - let v = make_float3(x, y, z); + let v = Float3::expr(x, y, z); v.length() }); } @@ -327,7 +327,7 @@ fn autodiff_vec3_length_squared() { let x = inputs[0]; let y = inputs[1]; let z = inputs[2]; - let v = make_float3(x, y, z); + let v = Float3::expr(x, y, z); v.length_squared() }); } @@ -337,21 +337,21 @@ fn autodiff_vec3_normalize() { let x = inputs[0]; let y = inputs[1]; let z = inputs[2]; - let v = make_float3(x, y, z); + let v = Float3::expr(x, y, z); v.normalize().x() }); autodiff_helper(-10.0..10.0, 1024 * 1024, 3, |inputs| { let x = inputs[0]; let y = inputs[1]; let z = inputs[2]; - let v = make_float3(x, y, z); + let v = Float3::expr(x, y, z); v.normalize().y() }); autodiff_helper(-10.0..10.0, 1024 * 1024, 3, |inputs| { let x = inputs[0]; let y = inputs[1]; let z = inputs[2]; - let v = make_float3(x, y, z); + let v = Float3::expr(x, y, z); v.normalize().z() }); } @@ -362,12 +362,12 @@ fn autodiff_vec3_cross_x() { let ax = inputs[0]; let ay = inputs[1]; let az = inputs[2]; - let a = def(make_float3(ax, ay, az)); + let a = Float3::expr(ax, ay, az).var(); let bx = inputs[3]; let by = inputs[4]; let bz = inputs[5]; - let b = def(make_float3(bx, by, bz)); - let v = def(a.cross(*b)); + let b = Float3::expr(bx, by, bz).var(); + let v = a.cross(*b).var(); *v.x() }); } @@ -377,12 +377,12 @@ fn autodiff_vec3_cross_y() { let ax = inputs[0]; let ay = inputs[1]; let az = inputs[2]; - let a = def(make_float3(ax, ay, az)); + let a = Float3::expr(ax, ay, az).var(); let bx = inputs[3]; let by = inputs[4]; let bz = inputs[5]; - let b = def(make_float3(bx, by, bz)); - let v = def(a.cross(*b)); + let b = Float3::expr(bx, by, bz).var(); + let v = a.cross(*b).var(); *v.x() }); } @@ -393,11 +393,11 @@ fn autodiff_vec3_cross_z() { let ax = inputs[0]; let ay = inputs[1]; let az = inputs[2]; - let a = make_float3(ax, ay, az); + let a = Float3::expr(ax, ay, az); let bx = inputs[3]; let by = inputs[4]; let bz = inputs[5]; - let b = make_float3(bx, by, bz); + let b = Float3::expr(bx, by, bz); let v = a.cross(b); v.z() }); @@ -408,11 +408,11 @@ fn autodiff_vec3_distance() { let ax = inputs[0]; let ay = inputs[1]; let az = inputs[2]; - let a = make_float3(ax, ay, az); + let a = Float3::expr(ax, ay, az); let bx = inputs[3]; let by = inputs[4]; let bz = inputs[5]; - let b = make_float3(bx, by, bz); + let b = Float3::expr(bx, by, bz); a.distance(b) }); } @@ -422,7 +422,7 @@ fn autodiff_vec3_replace() { let ax = inputs[0]; let ay = inputs[1]; let az = inputs[2]; - let a = make_float3(ax, ay, az); + let a = Float3::expr(ax, ay, az); let b = inputs[3]; let c = a.set_y(b); a.dot(c) @@ -434,19 +434,19 @@ fn autodiff_matmul() { let ax = inputs[0]; let ay = inputs[1]; let az = inputs[2]; - let a = make_float3(ax, ay, az); + let a = Float3::expr(ax, ay, az); let bx = inputs[0 + 3]; let by = inputs[1 + 3]; let bz = inputs[2 + 3]; - let b = make_float3(bx, by, bz); + let b = Float3::expr(bx, by, bz); let cx = inputs[0 + 6]; let cy = inputs[1 + 6]; let cz = inputs[2 + 6]; - let c = make_float3(cx, cy, cz); + let c = Float3::expr(cx, cy, cz); let dx = inputs[0 + 9]; let dy = inputs[1 + 9]; let dz = inputs[2 + 9]; - let d = make_float3(dx, dy, dz); + let d = Float3::expr(dx, dy, dz); let m = Mat3Expr::new(a, b, c); let o = m * d; o.x() @@ -458,19 +458,19 @@ fn autodiff_matmul_transpose() { let ax = inputs[0]; let ay = inputs[1]; let az = inputs[2]; - let a = make_float3(ax, ay, az); + let a = Float3::expr(ax, ay, az); let bx = inputs[0 + 3]; let by = inputs[1 + 3]; let bz = inputs[2 + 3]; - let b = make_float3(bx, by, bz); + let b = Float3::expr(bx, by, bz); let cx = inputs[0 + 6]; let cy = inputs[1 + 6]; let cz = inputs[2 + 6]; - let c = make_float3(cx, cy, cz); + let c = Float3::expr(cx, cy, cz); let dx = inputs[0 + 9]; let dy = inputs[1 + 9]; let dz = inputs[2 + 9]; - let d = make_float3(dx, dy, dz); + let d = Float3::expr(dx, dy, dz); let m = Mat3Expr::new(a, b, c); let o = m.transpose() * d; o.y() @@ -482,19 +482,19 @@ fn autodiff_matmul_2() { let ax = inputs[0]; let ay = inputs[1]; let az = inputs[2]; - let a = make_float3(ax, ay, az); + let a = Float3::expr(ax, ay, az); let bx = inputs[0 + 3]; let by = inputs[1 + 3]; let bz = inputs[2 + 3]; - let b = make_float3(bx, by, bz); + let b = Float3::expr(bx, by, bz); let cx = inputs[0 + 6]; let cy = inputs[1 + 6]; let cz = inputs[2 + 6]; - let c = make_float3(cx, cy, cz); + let c = Float3::expr(cx, cy, cz); let dx = inputs[0 + 9]; let dy = inputs[1 + 9]; let dz = inputs[2 + 9]; - let d = make_float3(dx, dy, dz); + let d = Float3::expr(dx, dy, dz); let m = Mat3Expr::new(a, b, c); let o = m * m * d; o.z() @@ -506,19 +506,19 @@ fn autodiff_matmul_4() { let ax = inputs[0]; let ay = inputs[1]; let az = inputs[2]; - let a = make_float3(ax, ay, az); + let a = Float3::expr(ax, ay, az); let bx = inputs[0 + 3]; let by = inputs[1 + 3]; let bz = inputs[2 + 3]; - let b = make_float3(bx, by, bz); + let b = Float3::expr(bx, by, bz); let cx = inputs[0 + 6]; let cy = inputs[1 + 6]; let cz = inputs[2 + 6]; - let c = make_float3(cx, cy, cz); + let c = Float3::expr(cx, cy, cz); let dx = inputs[0 + 9]; let dy = inputs[1 + 9]; let dz = inputs[2 + 9]; - let d = make_float3(dx, dy, dz); + let d = Float3::expr(dx, dy, dz); let m = Mat3Expr::new(a, b, c); let o = (m * m) * d; o.z() @@ -530,19 +530,19 @@ fn autodiff_matmul_5() { let ax = inputs[0]; let ay = inputs[1]; let az = inputs[2]; - let a = make_float3(ax, ay, az); + let a = Float3::expr(ax, ay, az); let bx = inputs[0 + 3]; let by = inputs[1 + 3]; let bz = inputs[2 + 3]; - let b = make_float3(bx, by, bz); + let b = Float3::expr(bx, by, bz); let cx = inputs[0 + 6]; let cy = inputs[1 + 6]; let cz = inputs[2 + 6]; - let c = make_float3(cx, cy, cz); + let c = Float3::expr(cx, cy, cz); let dx = inputs[0 + 9]; let dy = inputs[1 + 9]; let dz = inputs[2 + 9]; - let d = make_float3(dx, dy, dz); + let d = Float3::expr(dx, dy, dz); let m = Mat3Expr::new(a, b, c); let o = m.comp_mul(m) * d; o.z() @@ -554,15 +554,15 @@ fn autodiff_mat_det() { let ax = inputs[0]; let ay = inputs[1]; let az = inputs[2]; - let a = make_float3(ax, ay, az); + let a = Float3::expr(ax, ay, az); let bx = inputs[0 + 3]; let by = inputs[1 + 3]; let bz = inputs[2 + 3]; - let b = make_float3(bx, by, bz); + let b = Float3::expr(bx, by, bz); let cx = inputs[0 + 6]; let cy = inputs[1 + 6]; let cz = inputs[2 + 6]; - let c = make_float3(cx, cy, cz); + let c = Float3::expr(cx, cy, cz); let m = Mat3Expr::new(a, b, c); m.determinant() }); @@ -574,7 +574,7 @@ fn autodiff_mat_det() { // let x = inputs[0]; // let y = inputs[1]; // let z = inputs[2]; -// let v = make_float3(x, y, z); +// let v = Float3::expr(x, y, z); // v.reduce_min() // }); // } @@ -586,7 +586,7 @@ fn autodiff_mat_det() { // let x = inputs[0]; // let y = inputs[1]; // let z = inputs[2]; -// let v = make_float3(x, y, z); +// let v = Float3::expr(x, y, z); // v.reduce_max() // }); // } @@ -900,9 +900,9 @@ fn autodiff_if_phi3() { let tid = dispatch_id().x(); let x = buf_x.read(tid); let y = buf_y.read(tid); - let const_two = var!(f32, 2.0); - let const_three = var!(f32, 3.0); - let const_four = var!(f32); + let const_two = 2.0_f32.var(); + let const_three = 3.0_f32.var(); + let const_four = F32Var::zeroed(); autodiff(|| { requires_grad(x); @@ -962,11 +962,11 @@ fn autodiff_if_phi4() { let x = buf_x.read(tid); let y = buf_y.read(tid); - let consts = var!(Float3); + let consts = Float3Var::zeroed(); autodiff(|| { requires_grad(x); requires_grad(y); - consts.store(make_float3(2.0, 3.0, 4.0)); + consts.store(Float3::expr(2.0, 3.0, 4.0)); let const_two = consts.x(); let const_three = consts.y(); let const_four = consts.z(); @@ -1102,8 +1102,8 @@ fn autodiff_callable() { let x = buf_x.read(tid); let y = buf_y.read(tid); let t = buf_t.read(tid); - let dx = def(x); - let dy = def(y); + let dx = x.var(); + let dy = y.var(); callable.call(dx, dy, t); buf_dx.write(tid, *dx); buf_dy.write(tid, *dy); diff --git a/luisa_compute/tests/common.rs b/luisa_compute/tests/common.rs index 2aa6222..2a96cff 100644 --- a/luisa_compute/tests/common.rs +++ b/luisa_compute/tests/common.rs @@ -1,6 +1,6 @@ -use std::env::current_exe; -use luisa::*; +use luisa::prelude::*; use luisa_compute as luisa; +use std::env::current_exe; fn _signal_handler(signal: libc::c_int) { if signal == libc::SIGSEGV { panic!("segfault detected"); @@ -20,7 +20,7 @@ pub fn get_device() -> Device { }; ONCE.call_once(|| unsafe { if show_log { - init_logger_verbose(); + luisa::init_logger_verbose(); } libc::signal(libc::SIGSEGV, _signal_handler as usize); }); @@ -31,4 +31,4 @@ pub fn get_device() -> Device { let device = ctx.create_device(&device); device.create_buffer_from_slice(&[1.0f32]); device -} \ No newline at end of file +} diff --git a/luisa_compute/tests/misc.rs b/luisa_compute/tests/misc.rs index fd98971..0c288ad 100644 --- a/luisa_compute/tests/misc.rs +++ b/luisa_compute/tests/misc.rs @@ -1,5 +1,6 @@ +use luisa::lang::types::array::VLArrayVar; +use luisa::lang::types::core::*; use luisa::prelude::*; -use luisa::*; use luisa_compute as luisa; use luisa_compute_api_types::StreamTag; use rand::prelude::*; @@ -63,7 +64,7 @@ fn callable_return_mismatch() { let device = get_device(); let _abs = device.create_callable::) -> Expr>(&|x| { if_!(x.cmpgt(0.0), { - return_v(const_(true)); + return_v(true.expr()); }); -x }); @@ -75,7 +76,7 @@ fn callable_return_void_mismatch() { let device = get_device(); let _abs = device.create_callable::)>(&|x| { if_!(x.cmpgt(0.0), { - return_v(const_(true)); + return_v(true.expr()); }); x.store(-*x); }); @@ -130,7 +131,7 @@ fn callable() { let tid = dispatch_id().x(); let x = buf_x.read(tid); let y = buf_y.read(tid); - let z = var!(u32, add.call(x, y)); + let z = add.call(x, y).var(); write.call(buf_z, tid, z); buf_w.write(tid, z.load()); }); @@ -470,8 +471,8 @@ fn array_read_write() { let kernel = device.create_kernel::(&|| { let buf_x = x.var(); let tid = dispatch_id().x(); - let arr = local_zeroed::<[i32; 4]>(); - let i = local_zeroed::(); + let arr = Var::<[i32; 4]>::zeroed(); + let i = IntVar::zeroed(); while_!(i.load().cmplt(4), { arr.write(i.load().uint(), tid.int() + i.load()); i.store(i.load() + 1); @@ -494,7 +495,7 @@ fn array_read_write3() { let kernel = device.create_kernel::(&|| { let buf_x = x.var(); let tid = dispatch_id().x(); - let arr = local_zeroed::<[i32; 4]>(); + let arr = Var::<[i32; 4]>::zeroed(); for_range(0..4u32, |i| { arr.write(i, tid.int() + i.int()); }); @@ -516,7 +517,7 @@ fn array_read_write4() { let kernel = device.create_kernel::(&|| { let buf_x = x.var(); let tid = dispatch_id().x(); - let arr = local_zeroed::<[i32; 4]>(); + let arr = Var::<[i32; 4]>::zeroed(); for_range(0..6u32, |_| { for_range(0..4u32, |i| { arr.write(i, arr.read(i) + tid.int() + i.int()); @@ -547,8 +548,8 @@ fn array_read_write2() { let buf_x = x.var(); let buf_y = y.var(); let tid = dispatch_id().x(); - let arr = local_zeroed::<[i32; 4]>(); - let i = local_zeroed::(); + let arr = Var::<[i32; 4]>::zeroed(); + let i = IntVar::zeroed(); while_!(i.load().cmplt(4), { arr.write(i.load().uint(), tid.int() + i.load()); i.store(i.load() + 1); @@ -578,13 +579,13 @@ fn array_read_write_vla() { let buf_y = y.var(); let tid = dispatch_id().x(); let vl = VLArrayVar::::zero(4); - let i = local_zeroed::(); + let i = IntVar::zeroed(); while_!(i.load().cmplt(4), { vl.write(i.load().uint(), tid.int() + i.load()); i.store(i.load() + 1); }); - let arr = local_zeroed::<[i32; 4]>(); - let i = local_zeroed::(); + let arr = Var::<[i32; 4]>::zeroed(); + let i = IntVar::zeroed(); while_!(i.load().cmplt(4), { arr.write(i.load().uint(), vl.read(i.load().uint())); i.store(i.load() + 1); @@ -611,8 +612,8 @@ fn array_read_write_async_compile() { let kernel = device.create_kernel::(&|| { let buf_x = x.var(); let tid = dispatch_id().x(); - let arr = local_zeroed::<[i32; 4]>(); - let i = local_zeroed::(); + let arr = Var::<[i32; 4]>::zeroed(); + let i = IntVar::zeroed(); while_!(i.load().cmplt(4), { arr.write(i.load().uint(), tid.int() + i.load()); i.store(i.load() + 1); @@ -636,7 +637,7 @@ fn capture_same_buffer_multiple_view() { x.view(..).fill_fn(|i| i as f32); sum.view(..).fill(0.0); let shader = device.create_kernel::(&|| { - let tid = luisa::dispatch_id().x(); + let tid = dispatch_id().x(); let buf_x_lo = x.view(0..64).var(); let buf_x_hi = x.view(64..).var(); let x = if_!(tid.cmplt(64), { @@ -664,7 +665,7 @@ fn uniform() { x.view(..).fill_fn(|i| i as f32); sum.view(..).fill(0.0); let shader = device.create_kernel::(&|v: Expr| { - let tid = luisa::dispatch_id().x(); + let tid = dispatch_id().x(); let buf_x_lo = x.view(0..64).var(); let buf_x_hi = x.view(64..).var(); let x = if_!(tid.cmplt(64), { @@ -683,7 +684,7 @@ fn uniform() { let expected = (x.len() as f32 - 1.0) * x.len() as f32 * 0.5 * 6.0; assert!((actual - expected).abs() < 1e-4); } -#[derive(Clone, Copy, Debug, __Value)] +#[derive(Clone, Copy, Debug, Value)] #[repr(C)] struct Big { a: [f32; 32], @@ -719,11 +720,11 @@ fn byte_buffer() { let i1 = i1 as u64; let i2 = i2 as u64; let i3 = i3 as u64; - let v0 = def(buf.read::(i0)); - let v1 = def(buf.read::(i1)); - let v2 = def(buf.read::(i2)); - let v3 = def(buf.read::(i3)); - *v0.get_mut() = make_float3(1.0, 2.0, 3.0); + let v0 = buf.read::(i0).var(); + let v1 = buf.read::(i1).var(); + let v2 = buf.read::(i2).var(); + let v3 = buf.read::(i3).var(); + *v0.get_mut() = Float3::expr(1.0, 2.0, 3.0); for_range(0u32..32u32, |i| { v1.a().write(i, i.float() * 2.0); }); @@ -795,11 +796,11 @@ fn bindless_byte_buffer() { let i1 = i1 as u64; let i2 = i2 as u64; let i3 = i3 as u64; - let v0 = def(buf.read::(i0)); - let v1 = def(buf.read::(i1)); - let v2 = def(buf.read::(i2)); - let v3 = def(buf.read::(i3)); - *v0.get_mut() = make_float3(1.0, 2.0, 3.0); + let v0 = buf.read::(i0).var(); + let v1 = buf.read::(i1).var(); + let v2 = buf.read::(i2).var(); + let v3 = buf.read::(i3).var(); + *v0.get_mut() = Float3::expr(1.0, 2.0, 3.0); for_range(0u32..32u32, |i| { v1.a().write(i, i.float() * 2.0); }); diff --git a/luisa_compute_derive/src/lib.rs b/luisa_compute_derive/src/lib.rs index 777e42d..1f5a064 100644 --- a/luisa_compute_derive/src/lib.rs +++ b/luisa_compute_derive/src/lib.rs @@ -1,39 +1,29 @@ use proc_macro::TokenStream; -use syn::{parse::{Parse, ParseStream}, __private::quote::{quote_spanned, quote}, spanned::Spanned}; +use syn::__private::quote::quote; +use syn::parse::{Parse, ParseStream}; +use syn::spanned::Spanned; #[proc_macro_derive(Value)] pub fn derive_value(item: TokenStream) -> TokenStream { let item: syn::ItemStruct = syn::parse(item).unwrap(); - let compiler = luisa_compute_derive_impl::Compiler::new(false); + let compiler = luisa_compute_derive_impl::Compiler; compiler.derive_value(&item).into() } #[proc_macro_derive(BindGroup, attributes(luisa))] pub fn derive_kernel_arg(item: TokenStream) -> TokenStream { let item: syn::ItemStruct = syn::parse(item).unwrap(); - let compiler = luisa_compute_derive_impl::Compiler::new(false); + let compiler = luisa_compute_derive_impl::Compiler; compiler.derive_kernel_arg(&item).into() } #[proc_macro_derive(Aggregate)] pub fn derive_aggregate(item: TokenStream) -> TokenStream { let item: syn::Item = syn::parse(item).unwrap(); - let compiler = luisa_compute_derive_impl::Compiler::new(false); + let compiler = luisa_compute_derive_impl::Compiler; compiler.derive_aggregate(&item).into() } -#[proc_macro_derive(__Value)] -pub fn _derive_value(item: TokenStream) -> TokenStream { - let item: syn::ItemStruct = syn::parse(item).unwrap(); - let compiler = luisa_compute_derive_impl::Compiler::new(true); - compiler.derive_value(&item).into() -} -#[proc_macro_derive(__Aggregate)] -pub fn _derive_aggregate(item: TokenStream) -> TokenStream { - let item: syn::Item = syn::parse(item).unwrap(); - let compiler = luisa_compute_derive_impl::Compiler::new(true); - compiler.derive_aggregate(&item).into() -} struct LogInput { printer: syn::Expr, level: syn::Expr, @@ -75,28 +65,28 @@ pub fn _log(item: TokenStream) -> TokenStream { .enumerate() .map(|(i, a)| syn::Ident::new(&format!("__log_priv_arg{}", i), a.span())) .collect::>(); - quote!{ + quote! { { - #( let #arg_idents = #args; )*; let mut __log_priv_i = 0; let log_fn = Box::new(move |args: &[*const u32]| -> () { let mut i = 0; - luisa_compute::log::log!(#level, #fmt , #( + luisa_compute::printer::_log::log!(#level, #fmt , #( { - let ret = luisa_compute::lang::printer::_unpack_from_expr(args[i], #arg_idents); + let ret = luisa_compute::printer::_unpack_from_expr(args[i], #arg_idents); i += 1; ret } ), *); }); - let mut printer_args = luisa_compute::lang::PrinterArgs::new(); + let mut printer_args = luisa_compute::printer::PrinterArgs::new(); #( printer_args.append(#arg_idents); )* #printer._log(#level, printer_args, log_fn); } - }.into() + } + .into() } diff --git a/luisa_compute_derive_impl/src/bin/derive-debug.rs b/luisa_compute_derive_impl/src/bin/derive-debug.rs index b282d24..2720bea 100644 --- a/luisa_compute_derive_impl/src/bin/derive-debug.rs +++ b/luisa_compute_derive_impl/src/bin/derive-debug.rs @@ -2,7 +2,7 @@ use luisa_compute_derive_impl::*; use quote::ToTokens; fn main() { - let compiler = Compiler::new(false); + let compiler = Compiler; let item: syn::ItemStruct = syn::parse_str( r#" #[derive(__Value)] diff --git a/luisa_compute_derive_impl/src/lib.rs b/luisa_compute_derive_impl/src/lib.rs index a3edfdc..eca0e83 100644 --- a/luisa_compute_derive_impl/src/lib.rs +++ b/luisa_compute_derive_impl/src/lib.rs @@ -1,23 +1,16 @@ -use std::collections::HashSet; - use proc_macro2::{TokenStream, TokenTree}; use quote::{quote, quote_spanned}; -use syn::{spanned::Spanned, Attribute, Item, ItemEnum, ItemFn, ItemStruct, ItemTrait}; -pub struct Compiler { - inside_crate: bool, -} +use syn::spanned::Spanned; +use syn::{Attribute, Item, ItemEnum, ItemFn, ItemStruct}; +pub struct Compiler; impl Compiler { - fn crate_path(&self) -> TokenStream { - if self.inside_crate { - quote!(crate::lang) - } else { - quote!(luisa_compute::lang) - } + fn lang_path(&self) -> TokenStream { + quote!(::luisa_compute::lang) } - pub fn new(inside_crate: bool) -> Self { - Self { inside_crate } + fn runtime_path(&self) -> TokenStream { + quote!(::luisa_compute::runtime) } - pub fn compile_kernel(&self, func: &ItemFn) -> TokenStream { + pub fn compile_kernel(&self, _func: &ItemFn) -> TokenStream { todo!() } fn check_repr_c(&self, attribtes: &Vec) { @@ -54,6 +47,7 @@ impl Compiler { } } pub fn derive_kernel_arg(&self, struct_: &ItemStruct) -> TokenStream { + let runtime_path = self.runtime_path(); let span = struct_.span(); let name = &struct_.ident; let vis = &struct_.vis; @@ -92,33 +86,34 @@ impl Compiler { let parameter_name = syn::Ident::new(&format!("{}Var", name), name.span()); let parameter_def = quote!( #vis struct #parameter_name #generics { - #(#field_vis #field_names: <#field_types as luisa_compute::runtime::KernelArg>::Parameter),* + #(#field_vis #field_names: <#field_types as #runtime_path::KernelArg>::Parameter),* } ); quote_spanned!(span=> #parameter_def - impl #impl_generics luisa_compute::lang::KernelParameter for #parameter_name #ty_generics #where_clause{ - fn def_param(builder: &mut luisa_compute::KernelBuilder) -> Self { + impl #impl_generics #runtime_path::KernelParameter for #parameter_name #ty_generics #where_clause{ + fn def_param(builder: &mut #runtime_path::KernelBuilder) -> Self { Self{ - #(#field_names: luisa_compute::lang::KernelParameter::def_param(builder)),* + #(#field_names: #runtime_path::KernelParameter::def_param(builder)),* } } } - impl #impl_generics luisa_compute::runtime::KernelArg for #name #ty_generics #where_clause{ + impl #impl_generics #runtime_path::KernelArg for #name #ty_generics #where_clause{ type Parameter = #parameter_name #ty_generics; - fn encode(&self, encoder: &mut luisa_compute::KernelArgEncoder) { + fn encode(&self, encoder: &mut #runtime_path::KernelArgEncoder) { #(self.#field_names.encode(encoder);)* } } - impl #impl_generics luisa_compute::runtime::AsKernelArg<#name #ty_generics> for #name #ty_generics #where_clause { + impl #impl_generics #runtime_path::AsKernelArg<#name #ty_generics> for #name #ty_generics #where_clause { } ) } pub fn derive_value(&self, struct_: &ItemStruct) -> TokenStream { self.check_repr_c(&struct_.attrs); let span = struct_.span(); - let crate_path = self.crate_path(); + let lang_path = self.lang_path(); + let runtime_path = self.runtime_path(); let generics = &struct_.generics; let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); let marker_args = generics @@ -156,18 +151,18 @@ impl Compiler { quote_spanned!(span=> #[allow(dead_code, non_snake_case)] #[allow(unused_parens)] - #vis fn #ident (&self) -> #crate_path ::Expr<#ty> { - use #crate_path ::*; + #vis fn #ident (&self) -> #lang_path::types::Expr<#ty> { + use #lang_path::*; as FromNode>::from_node(__extract::<#ty>( self.node, #i, )) } #[allow(dead_code, non_snake_case)] #[allow(unused_parens)] - #vis fn #set_ident<__T:Into<#crate_path ::Expr<#ty>>>(&self, value: __T) -> Self { - use #crate_path ::*; + #vis fn #set_ident<__T:Into<#lang_path::types::Expr<#ty>>>(&self, value: __T) -> Self { + use #lang_path::*; let value = value.into(); - Self::from_node(#crate_path ::__insert::<#name #ty_generics>(self.node, #i, ToNode::node(&value))) + Self::from_node(#lang_path::__insert::<#name #ty_generics>(self.node, #i, ToNode::node(&value))) } ) }) @@ -183,15 +178,15 @@ impl Compiler { quote_spanned!(span=> #[allow(dead_code, non_snake_case)] #[allow(unused_parens)] - #vis fn #ident (&self) -> #crate_path:: Var<#ty> { - use #crate_path ::*; + #vis fn #ident (&self) -> #lang_path::types::Var<#ty> { + use #lang_path::*; as FromNode>::from_node(__extract::<#ty>( self.node, #i, )) } #[allow(dead_code, non_snake_case)] #[allow(unused_parens)] - #vis fn #set_ident<__T:Into<#crate_path ::Expr<#ty>>>(&self, value: __T) { + #vis fn #set_ident<__T:Into<#lang_path::types::Expr<#ty>>>(&self, value: __T) { let value = value.into(); self.#ident().store(value); } @@ -207,7 +202,7 @@ impl Compiler { .map(|f| { let ident = f.ident.as_ref().unwrap(); let ty = &f.ty; - quote_spanned!(span=> #vis #ident: #crate_path ::Expr<#ty>) + quote_spanned!(span=> #vis #ident: #lang_path::types::Expr<#ty>) }) .collect::>(); quote_spanned!(span=> @@ -222,20 +217,20 @@ impl Compiler { ) }; let type_of_impl = quote_spanned!(span=> - impl #impl_generics #crate_path ::TypeOf for #name #ty_generics #where_clause { + impl #impl_generics #lang_path::ir::TypeOf for #name #ty_generics #where_clause { #[allow(unused_parens)] - fn type_() -> #crate_path ::CArc< #crate_path ::Type> { - use #crate_path ::*; + fn type_() -> #lang_path::ir::CArc< #lang_path::ir::Type> { + use #lang_path::*; let size = std::mem::size_of::<#name #ty_generics>(); let alignment = std::mem::align_of::<#name #ty_generics>(); - let struct_type = StructType { - fields: CBoxedSlice::new(vec![#(<#field_types as TypeOf>::type_(),)*]), + let struct_type = ir::StructType { + fields: ir::CBoxedSlice::new(vec![#(<#field_types as ir::TypeOf>::type_(),)*]), size, alignment }; - let type_ = Type::Struct(struct_type); + let type_ = ir::Type::Struct(struct_type); assert_eq!(std::mem::size_of::<#name #ty_generics>(), type_.size()); - register_type(type_) + ir::register_type(type_) } } ); @@ -244,21 +239,21 @@ impl Compiler { #[derive(Clone, Copy, Debug)] #[allow(unused_parens)] #vis struct #expr_proxy_name #generics{ - node: #crate_path ::NodeRef, + node: #lang_path::NodeRef, _marker: std::marker::PhantomData<(#marker_args)>, } #[derive(Clone, Copy, Debug)] #[allow(unused_parens)] #vis struct #var_proxy_name #generics{ - node: #crate_path ::NodeRef, + node: #lang_path::NodeRef, _marker: std::marker::PhantomData<(#marker_args)>, } #[allow(unused_parens)] - impl #impl_generics #crate_path ::Aggregate for #expr_proxy_name #ty_generics #where_clause { - fn to_nodes(&self, nodes: &mut Vec<#crate_path ::NodeRef>) { + impl #impl_generics #lang_path::Aggregate for #expr_proxy_name #ty_generics #where_clause { + fn to_nodes(&self, nodes: &mut Vec<#lang_path::NodeRef>) { nodes.push(self.node); } - fn from_nodes<__I: Iterator>(iter: &mut __I) -> Self { + fn from_nodes<__I: Iterator>(iter: &mut __I) -> Self { Self{ node: iter.next().unwrap(), _marker:std::marker::PhantomData @@ -266,11 +261,11 @@ impl Compiler { } } #[allow(unused_parens)] - impl #impl_generics #crate_path ::Aggregate for #var_proxy_name #ty_generics #where_clause { - fn to_nodes(&self, nodes: &mut Vec<#crate_path ::NodeRef>) { + impl #impl_generics #lang_path::Aggregate for #var_proxy_name #ty_generics #where_clause { + fn to_nodes(&self, nodes: &mut Vec<#lang_path::NodeRef>) { nodes.push(self.node); } - fn from_nodes<__I: Iterator>(iter: &mut __I) -> Self { + fn from_nodes<__I: Iterator>(iter: &mut __I) -> Self { Self{ node: iter.next().unwrap(), _marker:std::marker::PhantomData @@ -278,36 +273,36 @@ impl Compiler { } } #[allow(unused_parens)] - impl #impl_generics #crate_path ::FromNode for #expr_proxy_name #ty_generics #where_clause { + impl #impl_generics #lang_path::FromNode for #expr_proxy_name #ty_generics #where_clause { #[allow(unused_assignments)] - fn from_node(node: #crate_path ::NodeRef) -> Self { + fn from_node(node: #lang_path::NodeRef) -> Self { Self { node, _marker:std::marker::PhantomData } } } #[allow(unused_parens)] - impl #impl_generics #crate_path ::ToNode for #expr_proxy_name #ty_generics #where_clause { - fn node(&self) -> #crate_path ::NodeRef { + impl #impl_generics #lang_path::ToNode for #expr_proxy_name #ty_generics #where_clause { + fn node(&self) -> #lang_path::NodeRef { self.node } } #[allow(unused_parens)] - impl #impl_generics #crate_path ::ExprProxy for #expr_proxy_name #ty_generics #where_clause { + impl #impl_generics #lang_path::types::ExprProxy for #expr_proxy_name #ty_generics #where_clause { type Value = #name #ty_generics; } #[allow(unused_parens)] - impl #impl_generics #crate_path ::FromNode for #var_proxy_name #ty_generics #where_clause { + impl #impl_generics #lang_path::FromNode for #var_proxy_name #ty_generics #where_clause { #[allow(unused_assignments)] - fn from_node(node: #crate_path ::NodeRef) -> Self { + fn from_node(node: #lang_path::NodeRef) -> Self { Self { node, _marker:std::marker::PhantomData } } } - impl #impl_generics #crate_path ::ToNode for #var_proxy_name #ty_generics #where_clause { - fn node(&self) -> #crate_path ::NodeRef { + impl #impl_generics #lang_path::ToNode for #var_proxy_name #ty_generics #where_clause { + fn node(&self) -> #lang_path::NodeRef { self.node } } #[allow(unused_parens)] - impl #impl_generics #crate_path ::VarProxy for #var_proxy_name #ty_generics #where_clause { + impl #impl_generics #lang_path::types::VarProxy for #var_proxy_name #ty_generics #where_clause { type Value = #name #ty_generics; } #[allow(unused_parens)] @@ -324,20 +319,20 @@ impl Compiler { } } #[allow(unused_parens)] - impl #impl_generics #crate_path ::CallableParameter for #expr_proxy_name #ty_generics #where_clause { - fn def_param(_:Option>, builder: &mut #crate_path ::KernelBuilder) -> Self { + impl #impl_generics #runtime_path::CallableParameter for #expr_proxy_name #ty_generics #where_clause { + fn def_param(_:Option>, builder: &mut #runtime_path::KernelBuilder) -> Self { builder.value::<#name #ty_generics>() } - fn encode(&self, encoder: &mut #crate_path ::CallableArgEncoder) { + fn encode(&self, encoder: &mut #runtime_path::CallableArgEncoder) { encoder.var(*self) } } #[allow(unused_parens)] - impl #impl_generics #crate_path ::CallableParameter for #var_proxy_name #ty_generics #where_clause { - fn def_param(_:Option>, builder: &mut #crate_path ::KernelBuilder) -> Self { + impl #impl_generics #runtime_path::CallableParameter for #var_proxy_name #ty_generics #where_clause { + fn def_param(_:Option>, builder: &mut #runtime_path::KernelBuilder) -> Self { builder.var::<#name #ty_generics>() } - fn encode(&self, encoder: &mut #crate_path ::CallableArgEncoder) { + fn encode(&self, encoder: &mut #runtime_path::CallableArgEncoder) { encoder.var(*self) } } @@ -347,21 +342,21 @@ impl Compiler { span=> #proxy_def #type_of_impl - impl #impl_generics #crate_path ::Value for #name #ty_generics #where_clause{ + impl #impl_generics #lang_path::types::Value for #name #ty_generics #where_clause{ type Expr = #expr_proxy_name #ty_generics; type Var = #var_proxy_name #ty_generics; fn fields() -> Vec { vec![#(stringify!(#field_names).into(),)*] } } - impl #impl_generics #crate_path ::StructInitiaizable for #name #ty_generics #where_clause{ + impl #impl_generics #lang_path::StructInitiaizable for #name #ty_generics #where_clause{ type Init = #ctor_proxy_name #ty_generics; } impl #impl_generics #expr_proxy_name #ty_generics #where_clause { #(#expr_proxy_field_methods)* - #vis fn new(#(#field_names: impl Into<#crate_path ::Expr<#field_types>>),*) -> Self { - use #crate_path ::*; - let node = #crate_path ::__compose::<#name #ty_generics>(&[ #( ToNode::node(&#field_names.into()) ),* ]); + #vis fn new(#(#field_names: impl Into<#lang_path::types::Expr<#field_types>>),*) -> Self { + use #lang_path::*; + let node = #lang_path::__compose::<#name #ty_generics>(&[ #( ToNode::node(&#field_names.into()) ),* ]); Self { node, _marker:std::marker::PhantomData } } } @@ -372,18 +367,18 @@ impl Compiler { } pub fn derive_aggregate_for_struct(&self, struct_: &ItemStruct) -> TokenStream { let span = struct_.span(); - let crate_path = self.crate_path(); + let lang_path = self.lang_path(); let name = &struct_.ident; let fields: Vec<_> = struct_.fields.iter().map(|f| f).collect(); let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect(); let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect(); quote_spanned!(span=> - impl #crate_path ::Aggregate for #name { - fn to_nodes(&self, nodes: &mut Vec<#crate_path ::NodeRef>) { + impl #lang_path::Aggregate for #name { + fn to_nodes(&self, nodes: &mut Vec<#lang_path::NodeRef>) { #(self.#field_names.to_nodes(nodes);)* } - fn from_nodes<__I: Iterator>(iter: &mut __I) -> Self { - #(let #field_names = <#field_types as #crate_path ::Aggregate>::from_nodes(iter);)* + fn from_nodes<__I: Iterator>(iter: &mut __I) -> Self { + #(let #field_names = <#field_types as #lang_path::Aggregate>::from_nodes(iter);)* Self{ #(#field_names,)* } @@ -393,7 +388,7 @@ impl Compiler { } pub fn derive_aggregate_for_enum(&self, enum_: &ItemEnum) -> TokenStream { let span = enum_.span(); - let crate_path = self.crate_path(); + let lang_path = self.lang_path(); let name = &enum_.ident; let variants = &enum_.variants; let to_nodes = variants.iter().enumerate().map(|(i, v)|{ @@ -410,7 +405,7 @@ impl Compiler { quote_spanned! { field_span=> Self::#name{#(#named),*}=>{ - nodes.push(__new_user_node(#i)); + nodes.push(#lang_path::__new_user_node(#i)); #(#named.to_nodes(nodes);)* } } @@ -420,13 +415,13 @@ impl Compiler { quote_spanned! { field_span=> Self::#name(#(#fields),*)=>{ - nodes.push(__new_user_node(#i)); + nodes.push(#lang_path::__new_user_node(#i)); #(#fields.to_nodes(nodes);)* } } }, syn::Fields::Unit => { - quote_spanned! { field_span=> Self::#name => { nodes.push(#crate_path ::__new_user_node(#i)); } } + quote_spanned! { field_span=> Self::#name => { nodes.push(#lang_path::__new_user_node(#i)); } } } } }).collect::>(); @@ -442,7 +437,7 @@ impl Compiler { let fields = u.unnamed.iter().enumerate().map(|(i, f)| syn::Ident::new(&format!("f{}", i), f.span())).collect::>(); quote_spanned! { field_span=> #i=> { - #(let #fields: #field_types = #crate_path :: Aggregate ::from_nodes(iter);)* + #(let #fields: #field_types = #lang_path:: Aggregate ::from_nodes(iter);)* Self::#name(#(#fields),*) }, } @@ -467,9 +462,9 @@ impl Compiler { }) .collect::>(); quote_spanned! {span=> - impl #crate_path ::Aggregate for #name{ + impl #lang_path::Aggregate for #name{ #[allow(non_snake_case)] - fn from_nodes>(iter: &mut I) -> Self { + fn from_nodes>(iter: &mut I) -> Self { let variant = iter.next().unwrap(); let variant = variant.unwrap_user_data::(); match variant{ @@ -478,7 +473,7 @@ impl Compiler { } } #[allow(non_snake_case)] - fn to_nodes(&self, nodes: &mut Vec){ + fn to_nodes(&self, nodes: &mut Vec<#lang_path::NodeRef>){ match self { #(#to_nodes)* } diff --git a/luisa_compute_sys/build.rs b/luisa_compute_sys/build.rs index 161b46f..3071336 100644 --- a/luisa_compute_sys/build.rs +++ b/luisa_compute_sys/build.rs @@ -1,6 +1,5 @@ -use std::io; -use std::path::Path; -use std::{env, fs, path::PathBuf}; +use std::path::{Path, PathBuf}; +use std::{env, fs}; fn cmake_build() -> PathBuf { let mut config = cmake::Config::new("./LuisaCompute"); @@ -113,8 +112,7 @@ fn copy_dlls(out_dir: &PathBuf) { for entry in std::fs::read_dir(out_dir).unwrap() { let entry = entry.unwrap(); let path = entry.path(); - if is_path_dll(&path) - { + if is_path_dll(&path) { // let target_dir = get_output_path(); let comps: Vec<_> = path.components().collect(); let copy_if_different = |src, dst| { @@ -191,4 +189,3 @@ fn main() { let out_dir = cmake_build(); copy_dlls(&out_dir); } - diff --git a/luisa_compute_track/Cargo.toml b/luisa_compute_track/Cargo.toml new file mode 100644 index 0000000..e208f38 --- /dev/null +++ b/luisa_compute_track/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "luisa_compute_track" +version = "0.1.1-alpha.1" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies.syn] +version = "2.0" +features = ["full", "visit-mut"] + +[dev-dependencies] +pretty_assertions = "1.4.0" + +[dependencies] +proc-macro-error = "1.0.4" +proc-macro2 = "1.0.67" +quote = "1.0" diff --git a/luisa_compute_track/src/lib.rs b/luisa_compute_track/src/lib.rs new file mode 100644 index 0000000..2fddca8 --- /dev/null +++ b/luisa_compute_track/src/lib.rs @@ -0,0 +1,208 @@ +use proc_macro2::TokenStream; +use proc_macro_error::emit_error; +use quote::quote; +use syn::spanned::Spanned; +use syn::visit_mut::*; +use syn::*; + +#[cfg(test)] +use pretty_assertions::assert_eq; + +struct TraceVisitor { + trait_path: TokenStream, + flow_path: TokenStream, +} + +impl VisitMut for TraceVisitor { + fn visit_block_mut(&mut self, node: &mut Block) { + let len = node.stmts.len(); + if len > 0 { + for stmt in node.stmts[0..len - 1].iter_mut() { + self.visit_stmt_mut(stmt); + } + visit_stmt_mut(self, node.stmts.last_mut().unwrap()); + } + } + fn visit_stmt_mut(&mut self, node: &mut Stmt) { + let span = node.span(); + match node { + Stmt::Expr(_, semi) => { + if semi.is_none() { + *semi = Some(Token![;](span)); + } + } + _ => {} + } + visit_stmt_mut(self, node); + } + fn visit_expr_mut(&mut self, node: &mut Expr) { + let flow_path = &self.flow_path; + let trait_path = &self.trait_path; + let span = node.span(); + match node { + Expr::If(expr) => { + let cond = &expr.cond; + let then_branch = &expr.then_branch; + let else_branch = &expr.else_branch; + if let Expr::Let(_) = **cond { + } else if let Some((_, else_branch)) = else_branch { + *node = parse_quote_spanned! {span=> + <_ as #trait_path::BoolIfElseMaybeExpr<_>>::if_then_else(#cond, || #then_branch, || #else_branch) + } + } else { + *node = parse_quote_spanned! {span=> + <_ as #trait_path::BoolIfMaybeExpr>::if_then(#cond, || #then_branch) + } + } + } + Expr::While(expr) => { + let cond = &expr.cond; + let body = &expr.body; + *node = parse_quote_spanned! {span=> + <_ as #trait_path::BoolWhileMaybeExpr>::while_loop(|| #cond, || #body) + } + } + Expr::Loop(expr) => { + let body = &expr.body; + *node = parse_quote_spanned! {span=> + #flow_path::loop_!(|| #body) + } + } + Expr::ForLoop(expr) => { + let pat = &expr.pat; + let body = &expr.body; + let expr = &expr.expr; + if let Expr::Range(range) = &**expr { + *node = parse_quote_spanned! {span=> + #flow_path::for_range(#range, |#pat| #body) + } + } + } + Expr::Binary(expr) => { + let op_fn_str = match &expr.op { + BinOp::Eq(_) => "eq", + BinOp::Ne(_) => "ne", + + BinOp::And(_) => "and", + BinOp::Or(_) => "or", + + BinOp::Lt(_) => "lt", + BinOp::Le(_) => "le", + BinOp::Ge(_) => "ge", + BinOp::Gt(_) => "gt", + _ => "", + }; + + if !op_fn_str.is_empty() { + let left = &expr.left; + let right = &expr.right; + let op_fn = Ident::new(op_fn_str, expr.op.span()); + if op_fn_str == "eq" || op_fn_str == "ne" { + *node = parse_quote_spanned! {span=> + <_ as #trait_path::EqMaybeExpr<_>>::#op_fn(#left, #right) + } + } else if op_fn_str == "and" || op_fn_str == "or" { + *node = parse_quote_spanned! {span=> + <_ as #trait_path::BoolLazyOpsMaybeExpr<_>>::#op_fn(#left, || #right) + } + } else { + *node = parse_quote_spanned! {span=> + <_ as #trait_path::PartialOrdMaybeExpr<_>>::#op_fn(#left, #right) + } + } + } + } + Expr::Return(expr) => { + if let Some(expr) = &expr.expr { + *node = parse_quote_spanned! {span=> + #flow_path::return_v(#expr) + }; + } else { + *node = parse_quote_spanned! {span=> + #flow_path::return_() + }; + } + } + Expr::Continue(expr) => { + if expr.label.is_some() { + emit_error!( + span, + "continue expression tracing with labels is not supported\nif this is intended to be a normal loop, use the `escape!` macro" + ); + } else { + *node = parse_quote_spanned! {span=> + #flow_path::continue_() + }; + } + } + Expr::Break(expr) => { + if expr.label.is_some() { + emit_error!( + span, + "break expression tracing with labels is not supported\nif this is intended to be a normal loop, use the `escape!` macro" + ); + } else { + *node = parse_quote_spanned! {span=> + #flow_path::break_() + }; + } + } + Expr::Macro(expr) => { + let path = &expr.mac.path; + if path.leading_colon.is_none() + && path.segments.len() == 1 + && path.segments[0].arguments.is_none() + { + let ident = &path.segments[0].ident; + if *ident == "escape" { + let tokens = &expr.mac.tokens; + *node = parse_quote_spanned! {span=> + #tokens + }; + return; + } + } + } + _ => {} + } + visit_expr_mut(self, node); + } +} + +#[proc_macro] +pub fn track(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + track_impl(parse_macro_input!(input as Expr)).into() +} + +fn track_impl(mut ast: Expr) -> TokenStream { + (TraceVisitor { + flow_path: quote!(::luisa_compute::lang::control_flow), + trait_path: quote!(::luisa_compute::lang::maybe_expr), + }) + .visit_expr_mut(&mut ast); + + quote!(#ast) +} + +#[test] +fn test_macro() { + #[rustfmt::skip] + assert_eq!( + track_impl(parse_quote!(|x: Expr, y: Expr| { + if x > y { + x * y + } else { + y * x + (x / 32.0 * PI).sin() + } + })) + .to_string(), + quote!(|x: Expr, y: Expr| { + <_ as ::luisa_compute::lang::maybe_expr::BoolIfElseMaybeExpr<_> >::if_then_else( + <_ as ::luisa_compute::lang::maybe_expr::PartialOrdMaybeExpr<_> >::gt(x, y), + | | { x * y }, + | | { y * x + (x / 32.0 * PI).sin() } + ) + }) + .to_string() + ); +} diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..d1bdb5e --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,2 @@ +ignore = ["luisa_compute_sys"] +imports_granularity = "Module"