-
Notifications
You must be signed in to change notification settings - Fork 7
/
polymorphism_advanced.rs
159 lines (156 loc) · 4.88 KB
/
polymorphism_advanced.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
use std::env::current_exe;
use std::f32::consts::PI;
use luisa::lang::poly::*;
use luisa::prelude::*;
use luisa_compute as luisa;
/// Structural key that tells the polymorphic container which concrete shader
/// implementation (and, for composite nodes, which child implementations) to trace.
///
/// `AddShader` nests the keys of its two operands, so a key built for the root of a
/// shader tree mirrors the shape of the whole tree.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ShaderDevirtualizationKey {
    /// Key of a `ConstShader` node.
    ConstShader,
    /// Key of a `SinShader` node.
    SinShader,
    /// Key of an `AddShader` node: (left operand key, right operand key).
    AddShader(Box<ShaderDevirtualizationKey>, Box<ShaderDevirtualizationKey>),
}
/// State passed down through `ShaderNode::evaluate` while tracing a shader tree.
#[derive(Copy, Clone)]
pub struct ShaderEvalContext<'a> {
    /// Container holding every shader node; composite nodes use it to resolve the
    /// `TagIndex` handles of their children.
    pub poly_shader: &'a Polymorphic<ShaderDevirtualizationKey, dyn ShaderNode>,
    /// Devirtualization key of the node currently being evaluated.
    pub key: &'a ShaderDevirtualizationKey,
}
/// A node of a shader expression tree that can emit its computation into a kernel.
pub trait ShaderNode {
    /// Emits this node's computation for the input value `x` and returns the result.
    ///
    /// `ctx` carries the polymorphic container and the devirtualization key of this node.
    fn evaluate(&self, x: Expr<f32>, ctx: &ShaderEvalContext<'_>) -> Expr<f32>;
}
/// Shader node that ignores its input and always yields a stored constant.
#[derive(Value, Clone, Copy)]
#[repr(C)]
pub struct ConstShader {
    // The constant returned by `evaluate`.
    value: f32,
}
impl ShaderNode for ConstShaderExpr {
    #[tracked]
    fn evaluate(&self, _x: Expr<f32>, _ctx: &ShaderEvalContext<'_>) -> Expr<f32> {
        // Neither the input nor the context matters: forward the stored constant.
        self.value
    }
}
impl_polymorphic!(ShaderNode, ConstShader);
/// Shader node returning the sine of its input.
#[derive(Value, Clone, Copy)]
#[repr(C)]
pub struct SinShader {
    // Unused filler; the node carries no data of its own.
    // NOTE(review): presumably present because `Value` types need at least one field — confirm.
    _pad: u32,
}
impl ShaderNode for SinShaderExpr {
    fn evaluate(&self, input: Expr<f32>, _ctx: &ShaderEvalContext<'_>) -> Expr<f32> {
        input.sin()
    }
}
impl_polymorphic!(ShaderNode, SinShader);
/// Shader node whose value is the sum of two child shaders.
#[derive(Value, Clone, Copy)]
#[repr(C)]
pub struct AddShader {
    // Handle of the left operand inside the polymorphic container.
    pub shader_a: TagIndex,
    // Handle of the right operand inside the polymorphic container.
    pub shader_b: TagIndex,
}
/// Evaluates `shader` at `x`, selecting its concrete implementation from `ctx.key`.
///
/// The key is translated into a tag, and `shader` is unwrapped to that single tag
/// (presumably so that only one implementation is traced instead of dispatching over
/// every registered type — TODO confirm against the `poly` module docs).
///
/// # Panics
///
/// Panics if no tag was registered for `ctx.key`, or if the key reported for the
/// resolved tag does not equal `ctx.key`.
fn eval_recursive_shader(
    shader: PolymorphicRef<'_, ShaderDevirtualizationKey, dyn ShaderNode>,
    x: Expr<f32>,
    ctx: &ShaderEvalContext<'_>,
) -> Expr<f32> {
    // A missing tag means the key was never pushed into the builder: a programming
    // error, so report which key was requested.
    let tag = shader.tag_from_key(ctx.key).unwrap_or_else(|| {
        panic!(
            "no polymorphic tag registered for devirtualization key {:?}",
            ctx.key
        )
    });
    shader.unwrap(tag, |key, shader| {
        assert_eq!(
            key, ctx.key,
            "tag resolved from a devirtualization key maps back to a different key"
        );
        shader.evaluate(x, ctx)
    })
}
/// `AddShader` evaluation: sums its two children, each devirtualized with its own sub-key.
impl ShaderNode for AddShaderExpr {
    /// Evaluates both children at `x` and returns their sum.
    ///
    /// `ctx.key` must be a `ShaderDevirtualizationKey::AddShader`; its first sub-key
    /// selects the implementation of `shader_a` and its second that of `shader_b`.
    ///
    /// # Panics
    ///
    /// Panics if `ctx.key` is not an `AddShader` key.
    fn evaluate(&self, x: Expr<f32>, ctx: &ShaderEvalContext<'_>) -> Expr<f32> {
        let key = ctx.key;
        match key {
            ShaderDevirtualizationKey::AddShader(a, b) => {
                // Resolve both child handles, then trace each child under its own
                // sub-key so it is unwrapped to a single concrete implementation.
                let shader_a = ctx.poly_shader.get(self.shader_a);
                let shader_b = ctx.poly_shader.get(self.shader_b);
                let value_a = eval_recursive_shader(
                    shader_a,
                    x,
                    &ShaderEvalContext {
                        poly_shader: ctx.poly_shader,
                        key: a.as_ref(),
                    },
                );
                let value_b = eval_recursive_shader(
                    shader_b,
                    x,
                    &ShaderEvalContext {
                        poly_shader: ctx.poly_shader,
                        key: b.as_ref(),
                    },
                );
                track!(value_a + value_b)
            }
            // Variants are listed explicitly (instead of `_`) so adding a new key
            // variant forces this match to be revisited.
            ShaderDevirtualizationKey::ConstShader | ShaderDevirtualizationKey::SinShader => {
                unreachable!(
                    "AddShader node evaluated with non-AddShader key {:?}",
                    key
                )
            }
        }
    }
}
impl_polymorphic!(ShaderNode, AddShader);
/// Builds the shader tree `sin(x) + (1.0 + 2.0)` from polymorphic nodes, evaluates it
/// on the CPU backend for the 100 inputs `x = i / 100 * PI` (`i` in `0..100`), and
/// checks every result against the same formula computed on the host.
///
/// # Panics
///
/// Panics if the current executable path cannot be determined, if the result buffer
/// does not hold 100 values, or if any device value differs from the host reference
/// by `1e-5` or more.
fn main() {
    let ctx = luisa::Context::new(current_exe().expect("failed to locate the current executable"));
    let device = ctx.create_device("cpu");
    let mut builder =
        PolymorphicBuilder::<ShaderDevirtualizationKey, dyn ShaderNode>::new(device.clone());
    // build shader = sin(x) + (1.0 + 2.0)
    // Each node is pushed with a key that mirrors its subtree's shape.
    let shader_const_1 = builder.push(
        ShaderDevirtualizationKey::ConstShader,
        ConstShader { value: 1.0 },
    );
    let shader_const_2 = builder.push(
        ShaderDevirtualizationKey::ConstShader,
        ConstShader { value: 2.0 },
    );
    let shader_sin = builder.push(ShaderDevirtualizationKey::SinShader, SinShader { _pad: 0 });
    let shader_add_1_2_key = ShaderDevirtualizationKey::AddShader(
        Box::new(ShaderDevirtualizationKey::ConstShader),
        Box::new(ShaderDevirtualizationKey::ConstShader),
    );
    let shader_add_1_2 = builder.push(
        shader_add_1_2_key.clone(),
        AddShader {
            shader_a: shader_const_1,
            shader_b: shader_const_2,
        },
    );
    let shader_final_key = ShaderDevirtualizationKey::AddShader(
        Box::new(ShaderDevirtualizationKey::SinShader),
        Box::new(shader_add_1_2_key),
    );
    let shader_final = builder.push(
        shader_final_key.clone(),
        AddShader {
            shader_a: shader_sin,
            shader_b: shader_add_1_2,
        },
    );
    let poly_shader = builder.build();
    let result = device.create_buffer::<f32>(100);
    // NOTE(review): every `dispatch` branch below is traced with a context built from
    // `shader_final_key`, ignoring the branch's own key; this looks correct only because
    // the runtime tag is always `shader_final`'s — confirm before reusing this pattern.
    let kernel = Kernel::<fn()>::new(
        &device,
        &track!(|| {
            let i = dispatch_id().x;
            let x = i.as_f32() / 100.0 * PI;
            let ctx = ShaderEvalContext {
                poly_shader: &poly_shader,
                key: &shader_final_key,
            };
            let tag_index = TagIndex::new_expr(shader_final.tag, shader_final.index);
            let v = poly_shader
                .get(tag_index)
                .dispatch(|_, _, shader| shader.evaluate(x, &ctx));
            result.var().write(i, v);
        }),
    );
    kernel.dispatch([100, 1, 1]);
    let result = result.copy_to_vec();
    assert_eq!(result.len(), 100, "result buffer has an unexpected length");
    // Compare every device value against the same formula evaluated on the host.
    for (i, &actual) in result.iter().enumerate() {
        let x = i as f32 / 100.0 * PI;
        let expected = x.sin() + (1.0 + 2.0);
        assert!(
            (actual - expected).abs() < 1e-5,
            "result[{}] = {} but expected {} (x = {})",
            i,
            actual,
            expected,
            x
        );
    }
    println!("OK");
}