Skip to content

Commit

Permalink
loop unrolling in shader
Browse files Browse the repository at this point in the history
  • Loading branch information
mhochsteger committed Oct 18, 2024
1 parent 67cc738 commit 3ae9e8c
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 100 deletions.
17 changes: 7 additions & 10 deletions utils/generate_interpolation_shader.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,12 +256,12 @@ def to_ccode(code):
print(ndof + basis_code.count("*"), "multiplications")
print(ndof + basis_code.count("+"), "additions")

def code_get_vec(dim, offset="i*stride"):
def code_get_vec(dim, i=0):
if dim == 1:
return f"{eltype.lower()}_function_values[offset+{offset}]"
return f"{eltype.lower()}_function_values[offset+{i}*stride]"
code = f"vec{dim}<f32>("
for i in range(dim):
code += f"{eltype.lower()}_function_values[offset+{offset}+{i}]"
code += f"{eltype.lower()}_function_values[offset+{i}*stride]"
if i < dim - 1:
code += ", "
code += ")"
Expand All @@ -278,12 +278,9 @@ def code_get_vec(dim, offset="i*stride"):

code += f"fn eval{eltype}P{p}{suffix}(offset: u32, stride: u32, lam: {lam_type}) -> {scal} {{\n"
code += f" let basis = eval{eltype}P{p}Basis(lam);\n"
code += (
f" var result: {scal} = basis[0] * {code_get_vec(scal_dim, '0')};\n"
)
code += f" for (var i: u32 = 1; i < {ndof}; i++) {{\n"
code += f" result += basis[i] * {code_get_vec(scal_dim)};\n"
code += f" }}\n"
code += f" var result: {scal} = basis[0] * {code_get_vec(scal_dim)};\n"
for i in range(1, ndof):
code += f" result += basis[{i}] * {code_get_vec(scal_dim, i)};\n"
code += f" return result;\n"
code += f"}}\n\n"
result += code
Expand All @@ -307,6 +304,6 @@ def code_get_vec(dim, offset="i*stride"):

code = ""
for et in [ET.SEGM, ET.TRIG, ET.TET][1:2]:
code += GenerateInterpolationFunction(et, orders=range(2, 7), scal_dims=range(1, 2))
code += GenerateInterpolationFunction(et, orders=range(1, 7), scal_dims=range(1, 2))

open("../webgpu/eval.wgsl", "w").write(code)
227 changes: 151 additions & 76 deletions webgpu/eval.wgsl
Original file line number Diff line number Diff line change
@@ -1,126 +1,201 @@
fn evalTrigP1Basis(lam: vec2<f32>) -> array<f32, 3> {
let x = lam.x;
let y = lam.y;
let z = 1.0 - x - y;
return array(x, y, z);
}

fn evalTrigP1(offset: u32, stride: u32, lam: vec2<f32>) -> f32 {
let basis = evalTrigP1Basis(lam);
var result: f32 = basis[0] * trig_function_values[offset + 0 * stride];
result += basis[1] * trig_function_values[offset + 1 * stride];
result += basis[2] * trig_function_values[offset + 2 * stride];
return result;
}

fn evalTrigP2Basis(lam: vec2<f32>) -> array<f32, 6> {
let x = lam.x;
let y = lam.y;
let z = 1.0 - x-y;
let x0 = 2.0*x;
return array(x*x, x0*y, y*y, x0*z, 2.0*y*z, z*z);
let z = 1.0 - x - y;
let x0 = 2.0 * x;
return array(x * x, x0 * y, y * y, x0 * z, 2.0 * y * z, z * z);
}

fn evalTrigP2(offset: u32, stride: u32, lam: vec2<f32>) -> f32 {
let basis = evalTrigP2Basis(lam);
var result: f32 = basis[0] * trig_function_values[offset+0];
for (var i: u32 = 1; i < 6; i++) {
result += basis[i] * trig_function_values[offset+i*stride];
}
var result: f32 = basis[0] * trig_function_values[offset + 0 * stride];
result += basis[1] * trig_function_values[offset + 1 * stride];
result += basis[2] * trig_function_values[offset + 2 * stride];
result += basis[3] * trig_function_values[offset + 3 * stride];
result += basis[4] * trig_function_values[offset + 4 * stride];
result += basis[5] * trig_function_values[offset + 5 * stride];
return result;
}

fn evalTrigP3Basis(lam: vec2<f32>) -> array<f32, 10> {
let x = lam.x;
let y = lam.y;
let z = 1.0 - x-y;
let x0 = 3.0*x*x;
let x1 = 3.0*y*y;
let x2 = 3.0*z*z;
return array(x*x*x, x0*y, x*x1, y*y*y, x0*z, 6.0*x*y*z, x1*z, x*x2, x2*y, z*z*z);
let z = 1.0 - x - y;
let x0 = 3.0 * x * x;
let x1 = 3.0 * y * y;
let x2 = 3.0 * z * z;
return array(x * x * x, x0 * y, x * x1, y * y * y, x0 * z, 6.0 * x * y * z, x1 * z, x * x2, x2 * y, z * z * z);
}

fn evalTrigP3(offset: u32, stride: u32, lam: vec2<f32>) -> f32 {
let basis = evalTrigP3Basis(lam);
var result: f32 = basis[0] * trig_function_values[offset+0];
for (var i: u32 = 1; i < 10; i++) {
result += basis[i] * trig_function_values[offset+i*stride];
}
var result: f32 = basis[0] * trig_function_values[offset + 0 * stride];
result += basis[1] * trig_function_values[offset + 1 * stride];
result += basis[2] * trig_function_values[offset + 2 * stride];
result += basis[3] * trig_function_values[offset + 3 * stride];
result += basis[4] * trig_function_values[offset + 4 * stride];
result += basis[5] * trig_function_values[offset + 5 * stride];
result += basis[6] * trig_function_values[offset + 6 * stride];
result += basis[7] * trig_function_values[offset + 7 * stride];
result += basis[8] * trig_function_values[offset + 8 * stride];
result += basis[9] * trig_function_values[offset + 9 * stride];
return result;
}

fn evalTrigP4Basis(lam: vec2<f32>) -> array<f32, 15> {
let x = lam.x;
let y = lam.y;
let z = 1.0 - x-y;
let x0 = 4.0*x*x*x;
let x1 = y*y;
let x2 = x*x;
let x3 = 6.0*x2;
let x4 = 4.0*y*y*y;
let x5 = 12.0*z;
let x6 = z*z;
let x7 = 4.0*z*z*z;
return array(x*x*x*x, x0*y, x1*x3, x*x4, y*y*y*y, x0*z, x2*x5*y, x*x1*x5, x4*z, x3*x6, 12.0*x*x6*y, 6.0*x1*x6, x*x7, x7*y, z*z*z*z);
let z = 1.0 - x - y;
let x0 = 4.0 * x * x * x;
let x1 = y * y;
let x2 = x * x;
let x3 = 6.0 * x2;
let x4 = 4.0 * y * y * y;
let x5 = 12.0 * z;
let x6 = z * z;
let x7 = 4.0 * z * z * z;
return array(x * x * x * x, x0 * y, x1 * x3, x * x4, y * y * y * y, x0 * z, x2 * x5 * y, x * x1 * x5, x4 * z, x3 * x6, 12.0 * x * x6 * y, 6.0 * x1 * x6, x * x7, x7 * y, z * z * z * z);
}

fn evalTrigP4(offset: u32, stride: u32, lam: vec2<f32>) -> f32 {
let basis = evalTrigP4Basis(lam);
var result: f32 = basis[0] * trig_function_values[offset+0];
for (var i: u32 = 1; i < 15; i++) {
result += basis[i] * trig_function_values[offset+i*stride];
}
var result: f32 = basis[0] * trig_function_values[offset + 0 * stride];
result += basis[1] * trig_function_values[offset + 1 * stride];
result += basis[2] * trig_function_values[offset + 2 * stride];
result += basis[3] * trig_function_values[offset + 3 * stride];
result += basis[4] * trig_function_values[offset + 4 * stride];
result += basis[5] * trig_function_values[offset + 5 * stride];
result += basis[6] * trig_function_values[offset + 6 * stride];
result += basis[7] * trig_function_values[offset + 7 * stride];
result += basis[8] * trig_function_values[offset + 8 * stride];
result += basis[9] * trig_function_values[offset + 9 * stride];
result += basis[10] * trig_function_values[offset + 10 * stride];
result += basis[11] * trig_function_values[offset + 11 * stride];
result += basis[12] * trig_function_values[offset + 12 * stride];
result += basis[13] * trig_function_values[offset + 13 * stride];
result += basis[14] * trig_function_values[offset + 14 * stride];
return result;
}

fn evalTrigP5Basis(lam: vec2<f32>) -> array<f32, 21> {
let x = lam.x;
let y = lam.y;
let z = 1.0 - x-y;
let x0 = 5.0*x*x*x*x;
let x1 = y*y;
let x2 = x*x*x;
let x3 = 10.0*x2;
let x4 = x*x;
let x5 = y*y*y;
let x6 = 10.0*x5;
let x7 = 5.0*y*y*y*y;
let x8 = 20.0*z;
let x9 = 30.0*x4;
let x10 = z*z;
let x11 = z*z*z;
let x12 = 10.0*x11;
let x13 = 5.0*z*z*z*z;
return array(x*x*x*x*x, x0*y, x1*x3, x4*x6, x*x7, y*y*y*y*y, x0*z, x2*x8*y, x1*x9*z, x*x5*x8, x7*z, x10*x3, x10*x9*y, 30.0*x*x1*x10, x10*x6, x12*x4, 20.0*x*x11*y, x1*x12, x*x13, x13*y, z*z*z*z*z);
let z = 1.0 - x - y;
let x0 = 5.0 * x * x * x * x;
let x1 = y * y;
let x2 = x * x * x;
let x3 = 10.0 * x2;
let x4 = x * x;
let x5 = y * y * y;
let x6 = 10.0 * x5;
let x7 = 5.0 * y * y * y * y;
let x8 = 20.0 * z;
let x9 = 30.0 * x4;
let x10 = z * z;
let x11 = z * z * z;
let x12 = 10.0 * x11;
let x13 = 5.0 * z * z * z * z;
return array(x * x * x * x * x, x0 * y, x1 * x3, x4 * x6, x * x7, y * y * y * y * y, x0 * z, x2 * x8 * y, x1 * x9 * z, x * x5 * x8, x7 * z, x10 * x3, x10 * x9 * y, 30.0 * x * x1 * x10, x10 * x6, x12 * x4, 20.0 * x * x11 * y, x1 * x12, x * x13, x13 * y, z * z * z * z * z);
}

fn evalTrigP5(offset: u32, stride: u32, lam: vec2<f32>) -> f32 {
let basis = evalTrigP5Basis(lam);
var result: f32 = basis[0] * trig_function_values[offset+0];
for (var i: u32 = 1; i < 21; i++) {
result += basis[i] * trig_function_values[offset+i*stride];
}
var result: f32 = basis[0] * trig_function_values[offset + 0 * stride];
result += basis[1] * trig_function_values[offset + 1 * stride];
result += basis[2] * trig_function_values[offset + 2 * stride];
result += basis[3] * trig_function_values[offset + 3 * stride];
result += basis[4] * trig_function_values[offset + 4 * stride];
result += basis[5] * trig_function_values[offset + 5 * stride];
result += basis[6] * trig_function_values[offset + 6 * stride];
result += basis[7] * trig_function_values[offset + 7 * stride];
result += basis[8] * trig_function_values[offset + 8 * stride];
result += basis[9] * trig_function_values[offset + 9 * stride];
result += basis[10] * trig_function_values[offset + 10 * stride];
result += basis[11] * trig_function_values[offset + 11 * stride];
result += basis[12] * trig_function_values[offset + 12 * stride];
result += basis[13] * trig_function_values[offset + 13 * stride];
result += basis[14] * trig_function_values[offset + 14 * stride];
result += basis[15] * trig_function_values[offset + 15 * stride];
result += basis[16] * trig_function_values[offset + 16 * stride];
result += basis[17] * trig_function_values[offset + 17 * stride];
result += basis[18] * trig_function_values[offset + 18 * stride];
result += basis[19] * trig_function_values[offset + 19 * stride];
result += basis[20] * trig_function_values[offset + 20 * stride];
return result;
}

fn evalTrigP6Basis(lam: vec2<f32>) -> array<f32, 28> {
let x = lam.x;
let y = lam.y;
let z = 1.0 - x-y;
let x0 = 6.0*x*x*x*x*x;
let x1 = y*y;
let x2 = x*x*x*x;
let x3 = 15.0*x2;
let x4 = y*y*y;
let x5 = x*x*x;
let x6 = 20.0*x5;
let x7 = x*x;
let x8 = y*y*y*y;
let x9 = 15.0*x8;
let x10 = 6.0*y*y*y*y*y;
let x11 = 30.0*z;
let x12 = 60.0*z;
let x13 = z*z;
let x14 = 60.0*x13;
let x15 = z*z*z;
let x16 = 60.0*x15;
let x17 = z*z*z*z;
let x18 = 15.0*x17;
let x19 = 6.0*z*z*z*z*z;
return array(x*x*x*x*x*x, x0*y, x1*x3, x4*x6, x7*x9, x*x10, y*y*y*y*y*y, x0*z, x11*x2*y, x1*x12*x5, x12*x4*x7, x*x11*x8, x10*z, x13*x3, x14*x5*y, 90.0*x1*x13*x7, x*x14*x4, x13*x9, x15*x6, x16*x7*y, x*x1*x16, 20.0*x15*x4, x18*x7, 30.0*x*x17*y, x1*x18, x*x19, x19*y, z*z*z*z*z*z);
let z = 1.0 - x - y;
let x0 = 6.0 * x * x * x * x * x;
let x1 = y * y;
let x2 = x * x * x * x;
let x3 = 15.0 * x2;
let x4 = y * y * y;
let x5 = x * x * x;
let x6 = 20.0 * x5;
let x7 = x * x;
let x8 = y * y * y * y;
let x9 = 15.0 * x8;
let x10 = 6.0 * y * y * y * y * y;
let x11 = 30.0 * z;
let x12 = 60.0 * z;
let x13 = z * z;
let x14 = 60.0 * x13;
let x15 = z * z * z;
let x16 = 60.0 * x15;
let x17 = z * z * z * z;
let x18 = 15.0 * x17;
let x19 = 6.0 * z * z * z * z * z;
return array(x * x * x * x * x * x, x0 * y, x1 * x3, x4 * x6, x7 * x9, x * x10, y * y * y * y * y * y, x0 * z, x11 * x2 * y, x1 * x12 * x5, x12 * x4 * x7, x * x11 * x8, x10 * z, x13 * x3, x14 * x5 * y, 90.0 * x1 * x13 * x7, x * x14 * x4, x13 * x9, x15 * x6, x16 * x7 * y, x * x1 * x16, 20.0 * x15 * x4, x18 * x7, 30.0 * x * x17 * y, x1 * x18, x * x19, x19 * y, z * z * z * z * z * z);
}

fn evalTrigP6(offset: u32, stride: u32, lam: vec2<f32>) -> f32 {
let basis = evalTrigP6Basis(lam);
var result: f32 = basis[0] * trig_function_values[offset+0];
for (var i: u32 = 1; i < 28; i++) {
result += basis[i] * trig_function_values[offset+i*stride];
}
var result: f32 = basis[0] * trig_function_values[offset + 0 * stride];
result += basis[1] * trig_function_values[offset + 1 * stride];
result += basis[2] * trig_function_values[offset + 2 * stride];
result += basis[3] * trig_function_values[offset + 3 * stride];
result += basis[4] * trig_function_values[offset + 4 * stride];
result += basis[5] * trig_function_values[offset + 5 * stride];
result += basis[6] * trig_function_values[offset + 6 * stride];
result += basis[7] * trig_function_values[offset + 7 * stride];
result += basis[8] * trig_function_values[offset + 8 * stride];
result += basis[9] * trig_function_values[offset + 9 * stride];
result += basis[10] * trig_function_values[offset + 10 * stride];
result += basis[11] * trig_function_values[offset + 11 * stride];
result += basis[12] * trig_function_values[offset + 12 * stride];
result += basis[13] * trig_function_values[offset + 13 * stride];
result += basis[14] * trig_function_values[offset + 14 * stride];
result += basis[15] * trig_function_values[offset + 15 * stride];
result += basis[16] * trig_function_values[offset + 16 * stride];
result += basis[17] * trig_function_values[offset + 17 * stride];
result += basis[18] * trig_function_values[offset + 18 * stride];
result += basis[19] * trig_function_values[offset + 19 * stride];
result += basis[20] * trig_function_values[offset + 20 * stride];
result += basis[21] * trig_function_values[offset + 21 * stride];
result += basis[22] * trig_function_values[offset + 22 * stride];
result += basis[23] * trig_function_values[offset + 23 * stride];
result += basis[24] * trig_function_values[offset + 24 * stride];
result += basis[25] * trig_function_values[offset + 25 * stride];
result += basis[26] * trig_function_values[offset + 26 * stride];
result += basis[27] * trig_function_values[offset + 27 * stride];
return result;
}

Expand Down
14 changes: 0 additions & 14 deletions webgpu/shader.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,8 @@ fn mainVertexTrigP1(@builtin(vertex_index) vertexId: u32, @builtin(instance_inde

@fragment
fn mainFragmentTrig(@location(0) p: vec3<f32>, @location(1) lam: vec2<f32>, @location(2) id: u32) -> @location(0) vec4<f32> {
// return vec4<f32>(lam, 1.0-lam.x-lam.y, 1.0);
checkClipping(p);

let value = evalTrig(id, 0u, lam);
// let value = 0.1;
return getColor(value);
}

Expand All @@ -95,14 +92,3 @@ fn mainFragmentEdge(@location(0) p: vec3<f32>) -> @location(0) vec4<f32> {
return vec4<f32>(0, 0, 0, 1.0);
}

fn evalSegP1(values: array<f32, 2>, lam: f32) -> f32 {
return mix(values[0], values[1], lam);
}

fn evalTrigP1(offset: u32, stride: u32, lam: vec2<f32>) -> f32 {
return trig_function_values[offset] * lam.x + trig_function_values[offset + stride] * lam.y + trig_function_values[offset + 2 * stride] * (1.0 - lam.x - lam.y);
}

fn evalTetP1(values: array<f32, 4>, lam: vec3<f32>) -> f32 {
return values[0] * lam.x + values[1] * lam.y + values[2] * lam.z + values[3] * (1.0 - lam.x - lam.y - lam.z);
}

0 comments on commit 3ae9e8c

Please sign in to comment.