Complete Refactor, also add track macro and ::expr, .var commands. #10

Merged 20 commits on Sep 18, 2023
4 changes: 3 additions & 1 deletion Cargo.toml
@@ -4,7 +4,9 @@ members = [
"luisa_compute_sys",
"luisa_compute_derive_impl",
"luisa_compute_derive",
"luisa_compute_track",
]
resolver = "2"
# exclude = [
# "luisa_compute_sys/LuisaCompute/src/api/luisa_compute_api_types",
# ]
# ]
144 changes: 96 additions & 48 deletions README.md
@@ -3,26 +3,33 @@ Rust frontend to LuisaCompute and more! Unified API and embedded DSL for high pe

To see the use of `luisa-compute-rs` in a high performance offline rendering system, check out [our research renderer](https://github.com/shiinamiyuki/akari_render).
## Table of Contents
* [Overview](#overview)
+ [Embedded Domain-Specific Language](#embedded-domain-specific-language)
+ [Automatic Differentiation](#automatic-differentiation)
+ [A CPU backend](#cpu-backend)
+ [IR Module for EDSL](#ir-module)
+ [Debuggability](#debuggability)
* [Usage](#usage)
+ [Building](#building)
+ [Variables and Expressions](#variables-and-expressions)
+ [Builtin Functions](#builtin-functions)
+ [Control Flow](#control-flow)
+ [Custom Data Types](#custom-data-types)
+ [Polymorphism](#polymorphism)
+ [Autodiff](#autodiff)
+ [Custom Operators](#custom-operators)
+ [Callable](#callable)
+ [Kernel](#kernel)
* [Advanced Usage](#advanced-usage)
* [Safety](#safety)
* [Citation](#citation)
- [luisa-compute-rs](#luisa-compute-rs)
- [Table of Contents](#table-of-contents)
- [Example](#example)
- [Vecadd](#vecadd)
- [Overview](#overview)
- [Embedded Domain-Specific Language](#embedded-domain-specific-language)
- [Automatic Differentiation](#automatic-differentiation)
- [CPU Backend](#cpu-backend)
- [IR Module](#ir-module)
- [Debuggability](#debuggability)
- [Usage](#usage)
- [Building](#building)
- [Variables and Expressions](#variables-and-expressions)
- [Builtin Functions](#builtin-functions)
- [Control Flow](#control-flow)
- [`track!` Macro](#track-mcro)
- [Custom Data Types](#custom-data-types)
- [Polymorphism](#polymorphism)
- [Autodiff](#autodiff)
- [Custom Operators](#custom-operators)
- [Callable](#callable)
- [Kernel](#kernel)
- [Advanced Usage](#advanced-usage)
- [Safety](#safety)
- [API](#api)
- [Backend](#backend)
- [Citation](#citation)

## Example
Try `cargo run --release --example path_tracer -- [cpu|cuda|dx|metal]`!
@@ -60,7 +67,7 @@ fn main() {
let tid = dispatch_id().x();
let x = buf_x.read(tid);
let y = buf_y.read(tid);
let vx = var!(f32); // create a local mutable variable
let vx = 0.0f32.var(); // create a local mutable variable
*vx.get_mut() += x;
buf_z.write(tid, vx.load() + y);
});
@@ -125,52 +132,44 @@ For each type, there are two EDSL proxy objects `Expr<T>` and `Var<T>`. `Expr<T>
*Note*: Every DSL object in host code **must** be immutable because Rust cannot overload `operator =`. For example:
```rust
// **no good**
let mut v = const_(0.0f32);
let mut v = 0.0f32.expr();
if_!(cond, {
v += 1.0;
});

// also **not good**
let v = Cell::new(const_(0.0f32));
let v = Cell::new(0.0f32.expr());
if_!(cond, {
v.set(v.get() + 1.0);
});

// **good**
let v = var!(f32);
let v = 0.0f32.var();
if_!(cond, {
*v.get_mut() += 1.0;
});
```
*Note*: You should not store the reference obtained by `v.get_mut()` for repeated use, as the assigned value is only updated when the reference returned by `v.get_mut()` is dropped. For example:
```rust
let v = var!(f32);
let v = 0.0f32.var();
let bad = v.get_mut();
*bad = 1.0;
let u = *v;
drop(bad);
cpu_dbg!(u); // prints 0.0
cpu_dbg!(*v); // now prints 1.0
```
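The write-back-on-drop behaviour described in the note above can be illustrated with a small host-only sketch. This is not the crate's implementation: `ToyVar`, `ToyVarMut`, and `stale_then_fresh` are hypothetical names invented for the illustration, and the real `Var<T>` records DSL operations rather than holding a plain `f32`. The sketch only shows why a long-lived guard leaves readers with a stale value.

```rust
use std::cell::Cell;
use std::ops::{Deref, DerefMut};

/// Hypothetical stand-in for a DSL variable (not the crate's `Var<T>`).
struct ToyVar {
    stored: Cell<f32>,
}

/// Hypothetical guard returned by `get_mut`; it holds a staged copy of the value.
struct ToyVarMut<'a> {
    var: &'a ToyVar,
    staged: f32,
}

impl ToyVar {
    fn new(value: f32) -> Self {
        ToyVar { stored: Cell::new(value) }
    }

    fn load(&self) -> f32 {
        self.stored.get()
    }

    fn get_mut(&self) -> ToyVarMut<'_> {
        ToyVarMut { var: self, staged: self.stored.get() }
    }
}

impl Deref for ToyVarMut<'_> {
    type Target = f32;
    fn deref(&self) -> &f32 {
        &self.staged
    }
}

impl DerefMut for ToyVarMut<'_> {
    fn deref_mut(&mut self) -> &mut f32 {
        &mut self.staged
    }
}

impl Drop for ToyVarMut<'_> {
    // The staged value reaches the variable only when the guard is dropped.
    fn drop(&mut self) {
        self.var.stored.set(self.staged);
    }
}

/// Returns the value observed while the guard is alive and after it is dropped.
fn stale_then_fresh() -> (f32, f32) {
    let v = ToyVar::new(0.0);
    let mut guard = v.get_mut();
    *guard = 1.0;
    let before = v.load(); // the guard is still alive, so the store has not happened
    drop(guard);
    let after = v.load(); // the guard was dropped, so the store is now visible
    (before, after)
}
```

In this toy model, writing `*v.get_mut() += 1.0;` as a single statement is safe because the temporary guard is dropped at the end of that statement.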
All operations except load/store should be performed on `Expr<T>`. `Var<T>` can only be used to load/store values. While `Expr<T>` and `Var<T>` are sufficient in most cases, they cannot be placed in an `impl` block. To do so, the exact names of these proxies are needed.
```rust
Expr<Bool> == Bool, Var<Bool> == BoolVar
Expr<f32> == Float32, Var<f32> == Float32Var
Expr<i32> == Int32, Var<i32> == Int32Var
Expr<u32> == UInt32, Var<u32> == UInt32Var
Expr<i64> == Int64, Var<i64> == Int64Var
Expr<u64> == UInt64, Var<u64> == UInt64Var
```
All operations except load/store should be performed on `Expr<T>`. `Var<T>` can only be used to load/store values.

As in the C++ EDSL, we additionally support the following vector/matrix types. Their proxy types are `XXXExpr` and `XXXVar`:

```rust
Bool2 // bool2 in C++
Bool3 // bool3 in C++
Bool4 // bool4 in C++
Vec2 // float2 in C++
Vec3 // float3 in C++
Vec4 // float4 in C++
Float2 // float2 in C++
Float3 // float3 in C++
Float4 // float4 in C++
Int2 // int2 in C++
Int3 // int3 in C++
Int4 // int4 in C++
@@ -181,20 +180,19 @@ Mat2 // float2x2 in C++
Mat3 // float3x3 in C++
Mat4 // float4x4 in C++
```
Array types `[T;N]` are also supported and their proxy types are `ArrayExpr<T, N>` and `ArrayVar<T, N>`. Call `arr.read(i)` and `arr.write(i, value)` on `ArrayVar<T, N>` for element access. `ArrayExpr<T,N>` can be stored to and loaded from `ArrayVar<T, N>`. The limitation is however the array length must be determined during host compile time. If runtime length is required, use `VLArrayVar<T>`. `VLArrayVar<T>::zero(length: usize` would create a zero initialized array. Similarly you can use `read` and `write` methods as well. To query the length of a `VLArrayVar<T>` in host, use ``VLArrayVar<T>::static_len()->usize`. To query the length in kernel, use ``VLArrayVar<T>::len()->Expr<u32>`
Array types `[T;N]` are also supported and their proxy types are `ArrayExpr<T, N>` and `ArrayVar<T, N>`. Call `arr.read(i)` and `arr.write(i, value)` on `ArrayVar<T, N>` for element access. `ArrayExpr<T,N>` can be stored to and loaded from `ArrayVar<T, N>`. The limitation, however, is that the array length must be known at host compile time. If a runtime length is required, use `VLArrayVar<T>`. `VLArrayVar<T>::zero(length: usize)` creates a zero-initialized array. You can use the `read` and `write` methods on it as well. To query the length of a `VLArrayVar<T>` on the host, use `VLArrayVar<T>::static_len()->usize`. To query the length in a kernel, use `VLArrayVar<T>::len()->Expr<u32>`.

Most operators are already overloaded; the only exception is comparison. We cannot overload comparison operators because `PartialOrd` cannot return a DSL type. Instead, use `cmpxx` methods such as `cmpgt`, `cmpeq`, etc. To cast a primitive/vector into another type, use `v.type()`. For example:
```rust
let iv = make_int2(1,1,1);
let iv = Int2::expr(1, 1, 1);
let fv = iv.float(); //fv is Expr<Float2>
let bv = fv.bool(); // bv is Expr<Bool2>
```
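To make the comparison limitation concrete, here is a self-contained sketch with a toy expression tree. `ToyExpr` and `build_comparison` are hypothetical names used only for illustration and are not part of the crate. The point is that `std::ops::Add` has an associated `Output` type, so `+` can return an expression node, while `PartialOrd::partial_cmp` is fixed to return `Option<Ordering>` and `>` always yields a host `bool`. A named method such as `cmpgt` is therefore one way to build a comparison node.

```rust
/// Hypothetical toy expression tree (not the crate's `Expr<T>`).
#[derive(Debug, Clone, PartialEq)]
enum ToyExpr {
    Const(f32),
    Add(Box<ToyExpr>, Box<ToyExpr>),
    Gt(Box<ToyExpr>, Box<ToyExpr>),
}

// `Add::Output` is chosen by the implementor, so `a + b` can build a node.
impl std::ops::Add for ToyExpr {
    type Output = ToyExpr;
    fn add(self, rhs: ToyExpr) -> ToyExpr {
        ToyExpr::Add(Box::new(self), Box::new(rhs))
    }
}

impl ToyExpr {
    // `a > b` must evaluate to a host `bool`, so the comparison is a named method.
    fn cmpgt(self, rhs: ToyExpr) -> ToyExpr {
        ToyExpr::Gt(Box::new(self), Box::new(rhs))
    }
}

/// Builds the tree for `(1.0 + 2.0) > 0.0` without evaluating it.
fn build_comparison() -> ToyExpr {
    (ToyExpr::Const(1.0) + ToyExpr::Const(2.0)).cmpgt(ToyExpr::Const(0.0))
}
```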
To perform a bitwise cast, use the `bitcast` function. `let fv:Expr<f32> = bitcast::<u32, f32>(const_(0u32));`
To perform a bitwise cast, use the `bitcast` function. `let fv:Expr<f32> = bitcast::<u32, f32>(0u32);`
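
As a host-side analogy only (this uses the Rust standard library, not the DSL `bitcast`), the difference between a value cast and a bitwise cast can be sketched as follows. The helper names `bits_to_f32` and `value_to_f32` are made up for the illustration.

```rust
/// Reinterprets the 32 bits of `bits` as an IEEE 754 single-precision float.
fn bits_to_f32(bits: u32) -> f32 {
    f32::from_bits(bits)
}

/// Converts the integer value itself to the nearest float.
fn value_to_f32(value: u32) -> f32 {
    value as f32
}
```

For instance, the bit pattern `0x3f80_0000` reinterpreted as `f32` is `1.0`, while the integer `1` reinterpreted bitwise is a tiny subnormal number rather than `1.0`.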

### Builtin Functions

We have extentded primitive types with methods similar to their host counterpart: `v.sin(), v.max(u)`, etc. Most methods accepts both a `Expr<T>` or a literal like `0.0`. However, the `select` function is slightly different as it do not accept literals. You need to use `select(cond, f_var, const_(1.0f32))`.

We have extended primitive types with methods similar to their host counterparts: `v.sin()`, `v.max(u)`, etc. Most methods accept either an `Expr<T>` or a literal like `0.0`. However, the `select` function is slightly different as it does not accept literals. You need to use `select(cond, f_var, 1.0f32.expr())`.

### Control Flow
*Note*: you cannot modify outer scope variables inside a control flow block by declaring the variable as `mut`. To modify outer scope variables, use `Var<T>` instead and call `*var.get_mut() = value` to store the value back to the outer scope.
@@ -223,8 +221,60 @@ let (x,y) = switch::<(Expr<i32>, Expr<f32>)>(value)
.finish();
```

### `track!` Macro

We also offer a `track!` macro that automatically rewrites control flow primitives and comparison operators. For example (from [`examples/mpm.rs`](luisa_compute/examples/mpm.rs)):

```rust
track!(|| {
// ...
let vx = select(
coord.x() < BOUND && (vx < 0.0f32)
|| coord.x() + BOUND > N_GRID as u32 && (vx > 0.0f32),
0.0f32.into(),
vx,
);
let vy = select(
coord.y() < BOUND && (vy < 0.0f32)
|| coord.y() + BOUND > N_GRID as u32 && (vy > 0.0f32),
0.0f32.into(),
vy,
);
// ...
})
```
is equivalent to:
```rust
|| {
// ...
let vx = select(
(coord.x().cmplt(BOUND) & vx.cmplt(0.0f32))
| (coord.x() + BOUND).cmpgt(N_GRID as u32) & vx.cmpgt(0.0f32),
0.0f32.into(),
vx,
);
let vy = select(
(coord.y().cmplt(BOUND) & vy.cmplt(0.0f32))
| (coord.y() + BOUND).cmpgt(N_GRID as u32) & vy.cmpgt(0.0f32),
0.0f32.into(),
vy,
);
// ...
}
```
Similarly,
```rust
track!(if cond { foo } else if bar { baz } else { qux })
```
will be converted to
```rust
if_!(cond, { foo }, { if_!(bar, { baz }, { qux }) })
```

Note that this macro will rewrite `while`, `for _ in x..y`, and `loop` expressions to versions using functions, which will then break the `break` and `continue` expressions. In order to avoid this, it's possible to use the `escape!` macro within a `track!` context to disable rewriting for an expression.
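
The reason function-based loops interfere with `break` and `continue` can be seen in plain Rust, independent of this crate. In the sketch below, `while_` and `count_to` are hypothetical helpers written for illustration and are not the functions `track!` actually emits. Once the loop body becomes a closure, a `break` inside it no longer refers to any enclosing loop and the compiler rejects it.

```rust
use std::cell::Cell;

/// Hypothetical function-based loop: the condition and the body are closures.
fn while_(mut cond: impl FnMut() -> bool, mut body: impl FnMut()) {
    while cond() {
        body();
    }
}

/// Counts up to `limit` using the function-based loop.
fn count_to(limit: u32) -> u32 {
    let i = Cell::new(0u32);
    while_(
        || i.get() < limit,
        || {
            // Writing `break;` here does not compile:
            // error[E0267]: `break` inside of a closure
            i.set(i.get() + 1);
        },
    );
    i.get()
}
```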

### Custom Data Types
To add custom data types to the EDSL, simply derive from `luisa::Value` macro. Note that `#[repr(C)]` is required for the struct to be compatible with C ABI. The proxy types are `XXXExpr` and `XXXVar`:
To add custom data types to the EDSL, simply derive the `Value` macro. Note that `#[repr(C)]` is required for the struct to be compatible with the C ABI. The proxy types are `XXXExpr` and `XXXVar`:

```rust
#[derive(Copy, Clone, Default, Debug, Value)]
@@ -234,7 +284,7 @@ pub struct MyVec2 {
pub y: f32,
}

let v = var!(MyVec2);
let v = MyVec2.var();
let sum = *v.x() + *v.y();
*v.x().get_mut() += 1.0;
```
@@ -282,8 +332,6 @@ autodiff(||{
buf_dv.write(.., dv);
buf_dm.write(.., dm);
});


```

### Custom Operators
@@ -304,8 +352,8 @@ let my_add = CpuFn::new(|args: &mut MyAddArgs| {

let args = MyAddArgsExpr::new(x, y, Float32::zero());
let result = my_add.call(args);

```

### Callable
Users can define device-only functions using Callables. Callables have a type signature similar to that of kernels: `Callable<fn(Args)->Ret>`.
The difference is that Callables are not dispatchable and can only be called from other Callables or Kernels. Callables can be created using `Device::create_callable`. To invoke a Callable, use `Callable::call(args...)`. Callables accept arguments such as resources (`BufferVar<T>`, etc.), expressions, and references (pass a `Var<T>` to the callable). For example:
@@ -317,7 +365,7 @@ let z = add.call(x, y);
let pass_by_ref = device.create_callable::<fn(Var<f32>)>(&|a| {
*a.get_mut() += 1.0;
});
let a = var!(f32, 1.0);
let a = 1.0f32.var();
pass_by_ref.call(a);
cpu_dbg!(*a); // prints 2.0
```
17 changes: 9 additions & 8 deletions luisa_compute/Cargo.toml
@@ -4,7 +4,7 @@ name = "luisa_compute"
version = "0.1.1-alpha.1"

[dependencies]
base64ct = {version = "1.5.0", features = ["alloc"]}
base64ct = { version = "1.5.0", features = ["alloc"] }
bumpalo = "3.12.0"
env_logger = "0.10.0"
glam = "0.24.0"
@@ -14,15 +14,16 @@ lazy_static = "1.4.0"
libc = "0.2"
libloading = "0.8"
log = "0.4"
luisa_compute_api_types = {path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_api_types", version="0.1.1-alpha.1"}
luisa_compute_backend = {path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_backend", version="0.1.1-alpha.1"}
luisa_compute_derive = {path = "../luisa_compute_derive", version="0.1.1-alpha.1"}
luisa_compute_derive_impl = {path = "../luisa_compute_derive_impl", version="0.1.1-alpha.1"}
luisa_compute_ir = {path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_ir", version="0.1.1-alpha.1"}
luisa_compute_sys = {path = "../luisa_compute_sys", version="0.1.1-alpha.1"}
luisa_compute_api_types = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_api_types", version = "0.1.1-alpha.1" }
luisa_compute_backend = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_backend", version = "0.1.1-alpha.1" }
luisa_compute_derive = { path = "../luisa_compute_derive", version = "0.1.1-alpha.1" }
luisa_compute_derive_impl = { path = "../luisa_compute_derive_impl", version = "0.1.1-alpha.1" }
luisa_compute_track = { path = "../luisa_compute_track", version = "0.1.1-alpha.1" }
luisa_compute_ir = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_ir", version = "0.1.1-alpha.1" }
luisa_compute_sys = { path = "../luisa_compute_sys", version = "0.1.1-alpha.1" }
parking_lot = "0.12.1"
rayon = "1.6.0"
serde = {version = "1.0", features = ["derive"]}
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
sha2 = "0.10"
winit = "0.28.3"
7 changes: 3 additions & 4 deletions luisa_compute/examples/atomic.rs
@@ -1,7 +1,6 @@
use std::env::current_exe;

use luisa::prelude::*;
use luisa::Context;
use luisa_compute as luisa;

fn main() {
@@ -11,12 +10,12 @@ fn main() {
let sum = device.create_buffer::<f32>(1);
x.view(..).fill_fn(|i| i as f32);
sum.view(..).fill(0.0);
let shader = device.create_kernel::<fn()>(&|| {
let shader = device.create_kernel::<fn()>(&track!(|| {
let buf_x = x.var();
let buf_sum = sum.var();
let tid = luisa::dispatch_id().x();
let tid = dispatch_id().x();
buf_sum.atomic_fetch_add(0, buf_x.read(tid));
});
}));
shader.dispatch([x.len() as u32, 1, 1]);
let mut sum_data = vec![0.0];
sum.view(..).copy_to(&mut sum_data);
20 changes: 12 additions & 8 deletions luisa_compute/examples/autodiff.rs
@@ -1,6 +1,8 @@
use std::{env::current_exe, f32::consts::PI};
use std::env::current_exe;
use std::f32::consts::PI;

use luisa::*;
use luisa::lang::diff::*;
use luisa::prelude::*;
use luisa_compute as luisa;
fn main() {
luisa::init_logger_verbose();
@@ -31,11 +33,13 @@ fn main() {
let buf_y = y.var();
let x = buf_x.read(tid);
let y = buf_y.read(tid);
let f = |x: Expr<f32>, y: Expr<f32>| {
if_!(x.cmpgt(y), { x * y }, else, {
let f = track!(|x: Expr<f32>, y: Expr<f32>| {
if x > y {
x * y
} else {
y * x + (x / 32.0 * PI).sin()
})
};
}
});
autodiff(|| {
requires_grad(x);
requires_grad(y);
@@ -45,8 +49,8 @@ fn main() {
dy_rev.write(tid, gradient(y));
});
forward_autodiff(2, || {
propagate_gradient(x, &[const_(1.0f32), const_(0.0f32)]);
propagate_gradient(y, &[const_(0.0f32), const_(1.0f32)]);
propagate_gradient(x, &[1.0f32.expr(), 0.0f32.expr()]);
propagate_gradient(y, &[0.0f32.expr(), 1.0f32.expr()]);
let z = f(x, y);
let dx = output_gradients(z)[0];
let dy = output_gradients(z)[1];
8 changes: 4 additions & 4 deletions luisa_compute/examples/backtrace.rs
@@ -1,10 +1,10 @@
use std::env::{current_exe, self};
use std::env::{self, current_exe};

use luisa::prelude::*;
use luisa_compute as luisa;

fn main() {
use luisa::*;
init_logger();
luisa::init_logger();
let ctx = Context::new(current_exe().unwrap());
env::set_var("LUISA_DEBUG", "1");
let device = ctx.create_device("cpu");
@@ -20,7 +20,7 @@ fn main() {
let tid = dispatch_id().x();
let x = buf_x.read(tid + 123);
let y = buf_y.read(tid);
let vx = var!(f32); // create a local mutable variable
let vx = Var::<f32>::zeroed(); // create a local mutable variable
vx.store(x);
buf_z.write(tid, vx.load() + y);
});
2 changes: 1 addition & 1 deletion luisa_compute/examples/bindgroup.rs
@@ -1,6 +1,6 @@
use std::env::current_exe;

use luisa::*;
use luisa::prelude::*;
use luisa_compute as luisa;
#[derive(BindGroup)]
struct MyArgStruct<T: Value> {