Skip to content

Commit

Permalink
Add edges to mesh again
Browse files Browse the repository at this point in the history
  • Loading branch information
mhochsteger committed Oct 18, 2024
1 parent 3ae9e8c commit 44c56b1
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 47 deletions.
27 changes: 17 additions & 10 deletions utils/generate_interpolation_shader.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ def getReferenceRules(order, sd):
return res


_eval_trig_template = """
fn evalTrig{suffix}(id: u32, icomp: u32, lam: vec2<f32>) -> f32 {{
let order = u32(trig_function_values[1]);
let ncomp = u32(trig_function_values[0]);
let ndof = (order + 1) * (order + 2) / 2;
_eval_template = """
fn eval{eltype}{suffix}(id: u32, icomp: u32, lam: {lam_type}) -> f32 {{
let order: u32 = u32(trig_function_values[1]);
let ncomp: u32 = u32(trig_function_values[0]);
let ndof: u32 = {ndof_expr};

let offset = ndof * id + VALUES_OFFSET;
let offset: u32 = ndof * id + VALUES_OFFSET;
let stride: u32 = ncomp;

{switch_order}
Expand Down Expand Up @@ -199,6 +199,15 @@ def GenerateInterpolationFunction(et, orders, scal_dims):
ET.PRISM: "Prism",
ET.PYRAMID: "Pyramid",
}[et]
ndof_expr = {
ET.SEGM: "order+1",
ET.TRIG: "(order+1)*(order+2)/2",
ET.TET: "(order+1)*(order+2)*(order+3)/6",
ET.QUAD: "(order+1)*(order+1)",
ET.HEX: "(order+1)*(order+1)*(order+1)",
ET.PRISM: "(order+1)*(order+1)*(order+2)/2",
ET.PYRAMID: "",
}[et]
result = ""
for p in orders:
print("\n\n=============", eltype, p)
Expand Down Expand Up @@ -296,14 +305,12 @@ def code_get_vec(dim, i=0):
orders_ = sorted(list(set(list(orders) + [1])))
for p in orders_:
switch_order += f" if order == {p} {{ return eval{eltype}P{p}{suffix}(offset, stride, lam); }}\n"
result += _eval_trig_template.format(
switch_order=switch_order, p=p, suffix=suffix
)
result += _eval_template.format(**locals())
return result


code = ""
for et in [ET.SEGM, ET.TRIG, ET.TET][1:2]:
for et in [ET.SEGM, ET.TRIG, ET.TET][0:2]:
code += GenerateInterpolationFunction(et, orders=range(1, 7), scal_dims=range(1, 2))

open("../webgpu/eval.wgsl", "w").write(code)
113 changes: 109 additions & 4 deletions webgpu/eval.wgsl
Original file line number Diff line number Diff line change
@@ -1,3 +1,108 @@
fn evalSegP1Basis(x: f32) -> array<f32, 2> {
let y = 1.0 - x;
return array(x, y);
}

fn evalSegP1(offset: u32, stride: u32, lam: f32) -> f32 {
let basis = evalSegP1Basis(lam);
var result: f32 = basis[0] * seg_function_values[offset + 0 * stride];
result += basis[1] * seg_function_values[offset + 1 * stride];
return result;
}

fn evalSegP2Basis(x: f32) -> array<f32, 3> {
let y = 1.0 - x;
return array(x * x, 2.0 * x * y, y * y);
}

fn evalSegP2(offset: u32, stride: u32, lam: f32) -> f32 {
let basis = evalSegP2Basis(lam);
var result: f32 = basis[0] * seg_function_values[offset + 0 * stride];
result += basis[1] * seg_function_values[offset + 1 * stride];
result += basis[2] * seg_function_values[offset + 2 * stride];
return result;
}

fn evalSegP3Basis(x: f32) -> array<f32, 4> {
let y = 1.0 - x;
return array(x * x * x, 3.0 * x * x * y, 3.0 * x * y * y, y * y * y);
}

fn evalSegP3(offset: u32, stride: u32, lam: f32) -> f32 {
let basis = evalSegP3Basis(lam);
var result: f32 = basis[0] * seg_function_values[offset + 0 * stride];
result += basis[1] * seg_function_values[offset + 1 * stride];
result += basis[2] * seg_function_values[offset + 2 * stride];
result += basis[3] * seg_function_values[offset + 3 * stride];
return result;
}

fn evalSegP4Basis(x: f32) -> array<f32, 5> {
let y = 1.0 - x;
return array(x * x * x * x, 4.0 * x * x * x * y, 6.0 * x * x * y * y, 4.0 * x * y * y * y, y * y * y * y);
}

fn evalSegP4(offset: u32, stride: u32, lam: f32) -> f32 {
let basis = evalSegP4Basis(lam);
var result: f32 = basis[0] * seg_function_values[offset + 0 * stride];
result += basis[1] * seg_function_values[offset + 1 * stride];
result += basis[2] * seg_function_values[offset + 2 * stride];
result += basis[3] * seg_function_values[offset + 3 * stride];
result += basis[4] * seg_function_values[offset + 4 * stride];
return result;
}

fn evalSegP5Basis(x: f32) -> array<f32, 6> {
let y = 1.0 - x;
return array(x * x * x * x * x, 5.0 * x * x * x * x * y, 10.0 * x * x * x * y * y, 10.0 * x * x * y * y * y, 5.0 * x * y * y * y * y, y * y * y * y * y);
}

fn evalSegP5(offset: u32, stride: u32, lam: f32) -> f32 {
let basis = evalSegP5Basis(lam);
var result: f32 = basis[0] * seg_function_values[offset + 0 * stride];
result += basis[1] * seg_function_values[offset + 1 * stride];
result += basis[2] * seg_function_values[offset + 2 * stride];
result += basis[3] * seg_function_values[offset + 3 * stride];
result += basis[4] * seg_function_values[offset + 4 * stride];
result += basis[5] * seg_function_values[offset + 5 * stride];
return result;
}

fn evalSegP6Basis(x: f32) -> array<f32, 7> {
let y = 1.0 - x;
return array(x * x * x * x * x * x, 6.0 * x * x * x * x * x * y, 15.0 * x * x * x * x * y * y, 20.0 * x * x * x * y * y * y, 15.0 * x * x * y * y * y * y, 6.0 * x * y * y * y * y * y, y * y * y * y * y * y);
}

fn evalSegP6(offset: u32, stride: u32, lam: f32) -> f32 {
let basis = evalSegP6Basis(lam);
var result: f32 = basis[0] * seg_function_values[offset + 0 * stride];
result += basis[1] * seg_function_values[offset + 1 * stride];
result += basis[2] * seg_function_values[offset + 2 * stride];
result += basis[3] * seg_function_values[offset + 3 * stride];
result += basis[4] * seg_function_values[offset + 4 * stride];
result += basis[5] * seg_function_values[offset + 5 * stride];
result += basis[6] * seg_function_values[offset + 6 * stride];
return result;
}


fn evalSeg(id: u32, icomp: u32, lam: f32) -> f32 {
let order: u32 = u32(trig_function_values[1]);
let ncomp: u32 = u32(trig_function_values[0]);
let ndof: u32 = order + 1;

let offset: u32 = ndof * id + VALUES_OFFSET;
let stride: u32 = ncomp;

if order == 1 { return evalSegP1(offset, stride, lam); }
if order == 2 { return evalSegP2(offset, stride, lam); }
if order == 3 { return evalSegP3(offset, stride, lam); }
if order == 4 { return evalSegP4(offset, stride, lam); }
if order == 5 { return evalSegP5(offset, stride, lam); }
if order == 6 { return evalSegP6(offset, stride, lam); }

return 0.0;
}
fn evalTrigP1Basis(lam: vec2<f32>) -> array<f32, 3> {
let x = lam.x;
let y = lam.y;
Expand Down Expand Up @@ -201,11 +306,11 @@ fn evalTrigP6(offset: u32, stride: u32, lam: vec2<f32>) -> f32 {


fn evalTrig(id: u32, icomp: u32, lam: vec2<f32>) -> f32 {
let order = u32(trig_function_values[1]);
let ncomp = u32(trig_function_values[0]);
let ndof = (order + 1) * (order + 2) / 2;
let order: u32 = u32(trig_function_values[1]);
let ncomp: u32 = u32(trig_function_values[0]);
let ndof: u32 = (order + 1) * (order + 2) / 2;

let offset = ndof * id + VALUES_OFFSET;
let offset: u32 = ndof * id + VALUES_OFFSET;
let stride: u32 = ncomp;

if order == 1 { return evalTrigP1(offset, stride, lam); }
Expand Down
2 changes: 1 addition & 1 deletion webgpu/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async def main():

gpu = await init_webgpu(js.document.getElementById("canvas"))

mesh = ngs.Mesh(unit_square.GenerateMesh(maxh=0.1))
mesh = ngs.Mesh(unit_square.GenerateMesh(maxh=0.5))
order = 6
gfu = ngs.GridFunction(ngs.H1(mesh, order=order))
# gfu.Set(ngs.IfPos(ngs.x-0.8, 1, 0))
Expand Down
70 changes: 43 additions & 27 deletions webgpu/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,26 +83,26 @@ def _create_pipelines(self):
)
self._create_pipeline_layout()
shader_module = self.gpu.device.createShaderModule(to_js({"code": shader_code}))
# edges_pipeline = self.gpu.device.createRenderPipeline(
# to_js(
# {
# "layout": self._pipeline_layout,
# "vertex": {
# "module": shader_module,
# "entryPoint": "mainVertexEdge",
# },
# "fragment": {
# "module": shader_module,
# "entryPoint": "mainFragmentEdge",
# "targets": [{"format": self.gpu.format}],
# },
# "primitive": {"topology": "line-list"},
# "depthStencil": {
# **self.gpu.depth_stencil,
# },
# }
# )
# )
edges_pipeline = self.gpu.device.createRenderPipeline(
to_js(
{
"layout": self._pipeline_layout,
"vertex": {
"module": shader_module,
"entryPoint": "mainVertexEdgeP1",
},
"fragment": {
"module": shader_module,
"entryPoint": "mainFragmentEdge",
"targets": [{"format": self.gpu.format}],
},
"primitive": {"topology": "line-list"},
"depthStencil": {
**self.gpu.depth_stencil,
},
}
)
)

trigs_pipeline = self.gpu.device.createRenderPipeline(
to_js(
Expand Down Expand Up @@ -133,14 +133,14 @@ def _create_pipelines(self):
)

self.pipelines = {
# "edges": edges_pipeline,
"edges": edges_pipeline,
"trigs": trigs_pipeline,
}

def render(self, encoder):
# encoder.setPipeline(self.pipelines["edges"])
# encoder.setBindGroup(0, self._bind_group)
# encoder.draw(2, self.n_edges, 0, 0)
encoder.setPipeline(self.pipelines["edges"])
encoder.setBindGroup(0, self._bind_group)
encoder.draw(2, 3 * self.n_trigs, 0, 0)

encoder.setPipeline(self.pipelines["trigs"])
encoder.setBindGroup(0, self._bind_group)
Expand Down Expand Up @@ -183,6 +183,22 @@ def create_mesh_buffers(device, region, curve_order=1):

n_trigs = len(mesh.ngmesh.Elements2D())

edge_points = points[2:].reshape(-1, 3, 3)
edges = np.zeros((n_trigs, 3, 2, 3), dtype=np.float32)
for i in range(3):
edges[:, i, 0, :] = edge_points[:, i, :]
edges[:, i, 1, :] = edge_points[:, (i + 1) % 3, :]
edge_data = js.Uint8Array.new(edges.flatten().tobytes())
edge_buffer = device.createBuffer(
to_js(
{
"size": edge_data.length,
"usage": js.GPUBufferUsage.STORAGE | js.GPUBufferUsage.COPY_DST,
}
)
)
device.queue.writeBuffer(edge_buffer, 0, edge_data)

trigs = np.zeros(
n_trigs,
dtype=[
Expand All @@ -194,16 +210,16 @@ def create_mesh_buffers(device, region, curve_order=1):
trigs["index"] = [1] * n_trigs
data = js.Uint8Array.new(trigs.tobytes())

buffer = device.createBuffer(
trigs_buffer = device.createBuffer(
to_js(
{
"size": data.length,
"usage": js.GPUBufferUsage.STORAGE | js.GPUBufferUsage.COPY_DST,
}
)
)
device.queue.writeBuffer(buffer, 0, data)
return {"trigs": buffer}
device.queue.writeBuffer(trigs_buffer, 0, data)
return {"trigs": trigs_buffer, "edges": edge_buffer}


def create_function_value_buffers(device, cf, region, order):
Expand Down
22 changes: 17 additions & 5 deletions webgpu/shader.wgsl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
struct Edge { v: vec2<u32> };
struct EdgeP1 { p: array<f32, 6> };

struct Segment { v: vec2<u32>, index: i32 };
struct TrigP1 { p: array<f32, 9>, index: i32 }; // 3 vertices with 3 coordinates each, don't use vec3 due to 16 byte alignment
struct TrigP2 { p: array<f32, 18>, index: i32 };

Expand All @@ -21,9 +20,10 @@ const VALUES_OFFSET: u32 = 2; // storing number of components and order of basis
@group(0) @binding(1) var colormap : texture_1d<f32>;
@group(0) @binding(2) var colormap_sampler : sampler;

// @group(0) @binding(4) var<storage> edges : array<Edge>;
@group(0) @binding(4) var<storage> edges_p1 : array<EdgeP1>;
@group(0) @binding(5) var<storage> trigs_p1 : array<TrigP1>;
@group(0) @binding(6) var<storage> trig_function_values : array<f32>;
@group(0) @binding(7) var<storage> seg_function_values : array<f32>;

struct VertexOutput1d {
@builtin(position) fragPosition: vec4<f32>,
Expand Down Expand Up @@ -63,9 +63,21 @@ fn getColor(value: f32) -> vec4<f32> {
}

@vertex
fn mainVertexTrigP1(@builtin(vertex_index) vertexId: u32, @builtin(instance_index) trigId: u32) -> VertexOutput2d {
let v_offset = 3 * (3 * trigId + vertexId) + VALUES_OFFSET;
fn mainVertexEdgeP1(@builtin(vertex_index) vertexId: u32, @builtin(instance_index) edgeId: u32) -> VertexOutput1d {
let edge = edges_p1[edgeId];
var p: vec3<f32> = vec3<f32>(edge.p[3 * vertexId], edge.p[3 * vertexId + 1], edge.p[3 * vertexId + 2]);

var lam: f32 = 0.0;
if vertexId == 0 {
lam = 1.0;
}

var position = calcPosition(p);
return VertexOutput1d(position, p, lam, edgeId);
}

@vertex
fn mainVertexTrigP1(@builtin(vertex_index) vertexId: u32, @builtin(instance_index) trigId: u32) -> VertexOutput2d {
let trig = trigs_p1[trigId];
var p = vec3<f32>(trig.p[3 * vertexId], trig.p[3 * vertexId + 1], trig.p[3 * vertexId + 2]);

Expand Down

0 comments on commit 44c56b1

Please sign in to comment.