Skip to content

Commit

Permalink
fixed ssa
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Sep 3, 2023
1 parent d175099 commit cbab1f9
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 13 deletions.
91 changes: 79 additions & 12 deletions luisa_compute/tests/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,13 +380,13 @@ fn autodiff_vec3_cross_x() {
let ax = inputs[0];
let ay = inputs[1];
let az = inputs[2];
let a = make_float3(ax, ay, az);
let a = def(make_float3(ax, ay, az));
let bx = inputs[3];
let by = inputs[4];
let bz = inputs[5];
let b = make_float3(bx, by, bz);
let v = a.cross(b);
v.x()
let b = def(make_float3(bx, by, bz));
let v = def(a.cross(*b));
*v.x()
});
}
#[test]
Expand All @@ -395,13 +395,13 @@ fn autodiff_vec3_cross_y() {
let ax = inputs[0];
let ay = inputs[1];
let az = inputs[2];
let a = make_float3(ax, ay, az);
let a = def(make_float3(ax, ay, az));
let bx = inputs[3];
let by = inputs[4];
let bz = inputs[5];
let b = make_float3(bx, by, bz);
let v = a.cross(b);
v.y()
let b = def(make_float3(bx, by, bz));
let v = def(a.cross(*b));
*v.x()
});
}

Expand Down Expand Up @@ -918,15 +918,82 @@ fn autodiff_if_phi3() {
let tid = dispatch_id().x();
let x = buf_x.read(tid);
let y = buf_y.read(tid);
let const_two = var!(f32, 2.0);
let const_three = var!(f32, 3.0);
let const_four = var!(f32);

autodiff(|| {
requires_grad(x);
requires_grad(y);
const_four.store(4.0);
let c = x.cmpgt(*const_three).int();
let z = if_!(x.cmpgt(y), {
switch::<Expr<f32>>(c)
.case(0, || x * *const_two)
.default(|| x * *const_four)
.finish() * *const_two
}, else {
y * 0.5
});
backward(z);
buf_dx.write(tid, gradient(x));
buf_dy.write(tid, gradient(y));
});
});
kernel.dispatch([1024, 1, 1]);
let dx = dx.view(..).copy_to_vec();
let dy = dy.view(..).copy_to_vec();
let x = x.view(..).copy_to_vec();
let y = y.view(..).copy_to_vec();
let cache_dir = kernel.cache_dir();
for i in 0..1024 {
if x[i] > y[i] {
if x[i] > 3.0 {
assert_eq!(dx[i], 8.0, "{} cache_dir: {:?}", dx[i], cache_dir);
assert_eq!(dy[i], 0.0, "{} cache_dir: {:?}", dy[i], cache_dir);
} else {
assert_eq!(dx[i], 4.0, "{} cache_dir: {:?}", dx[i], cache_dir);
assert_eq!(dy[i], 0.0, "{} cache_dir: {:?}", dy[i], cache_dir);
}
} else {
assert_eq!(dx[i], 0.0, "{} cache_dir: {:?}", dx[i], cache_dir);
assert_eq!(dy[i], 0.5, "{} cache_dir: {:?}", dy[i], cache_dir);
}
}
}
#[test]
fn autodiff_if_phi4() {
let device = get_device();
let x: Buffer<f32> = device.create_buffer(1024);
let y: Buffer<f32> = device.create_buffer(1024);
let dx: Buffer<f32> = device.create_buffer(1024);
let dy: Buffer<f32> = device.create_buffer(1024);
let mut rng = rand::thread_rng();
x.view(..).fill_fn(|_| rng.gen());
y.view(..).fill_fn(|_| rng.gen());
let kernel = device.create_kernel::<()>(&|| {
let buf_x = x.var();
let buf_y = y.var();
let buf_dx = dx.var();
let buf_dy = dy.var();
let tid = dispatch_id().x();
let x = buf_x.read(tid);
let y = buf_y.read(tid);

let consts = var!(Float3);
autodiff(|| {
requires_grad(x);
requires_grad(y);
let c = x.cmpgt(3.0).int();
consts.store(make_float3(2.0,3.0,4.0));
let const_two = consts.x();
let const_three = consts.y();
let const_four = consts.z();
let c = x.cmpgt(*const_three).int();
let z = if_!(x.cmpgt(y), {
switch::<Expr<f32>>(c)
.case(0, || x * 2.0)
.default(|| x * 4.0)
.finish() * 2.0
.case(0, || x * *const_two)
.default(|| x * *const_four)
.finish() * *const_two
}, else {
y * 0.5
});
Expand Down

0 comments on commit cbab1f9

Please sign in to comment.