diff --git a/utils/generate_interpolation_shader.py b/utils/generate_interpolation_shader.py index 03cf3fd..3d3f3d1 100644 --- a/utils/generate_interpolation_shader.py +++ b/utils/generate_interpolation_shader.py @@ -42,13 +42,13 @@ def getReferenceRules(order, sd): return res -_eval_trig_template = """ -fn evalTrig{suffix}(id: u32, icomp: u32, lam: vec2) -> f32 {{ - let order = u32(trig_function_values[1]); - let ncomp = u32(trig_function_values[0]); - let ndof = (order + 1) * (order + 2) / 2; +_eval_template = """ +fn eval{eltype}{suffix}(id: u32, icomp: u32, lam: {lam_type}) -> f32 {{ + let order: u32 = u32(trig_function_values[1]); + let ncomp: u32 = u32(trig_function_values[0]); + let ndof: u32 = {ndof_expr}; - let offset = ndof * id + VALUES_OFFSET; + let offset: u32 = ndof * id + VALUES_OFFSET; let stride: u32 = ncomp; {switch_order} @@ -199,6 +199,15 @@ def GenerateInterpolationFunction(et, orders, scal_dims): ET.PRISM: "Prism", ET.PYRAMID: "Pyramid", }[et] + ndof_expr = { + ET.SEGM: "order+1", + ET.TRIG: "(order+1)*(order+2)/2", + ET.TET: "(order+1)*(order+2)*(order+3)/6", + ET.QUAD: "(order+1)*(order+1)", + ET.HEX: "(order+1)*(order+1)*(order+1)", + ET.PRISM: "(order+1)*(order+1)*(order+2)/2", + ET.PYRAMID: "", + }[et] result = "" for p in orders: print("\n\n=============", eltype, p) @@ -296,14 +305,12 @@ def code_get_vec(dim, i=0): orders_ = sorted(list(set(list(orders) + [1]))) for p in orders_: switch_order += f" if order == {p} {{ return eval{eltype}P{p}{suffix}(offset, stride, lam); }}\n" - result += _eval_trig_template.format( - switch_order=switch_order, p=p, suffix=suffix - ) + result += _eval_template.format(**locals()) return result code = "" -for et in [ET.SEGM, ET.TRIG, ET.TET][1:2]: +for et in [ET.SEGM, ET.TRIG, ET.TET][0:2]: code += GenerateInterpolationFunction(et, orders=range(1, 7), scal_dims=range(1, 2)) open("../webgpu/eval.wgsl", "w").write(code) diff --git a/webgpu/eval.wgsl b/webgpu/eval.wgsl index b01e42c..7236ce6 100644 --- a/webgpu/eval.wgsl +++ b/webgpu/eval.wgsl @@ -1,3 +1,108 @@ +fn evalSegP1Basis(x: f32) -> array { + let y = 1.0 - x; + return array(x, y); +} + +fn evalSegP1(offset: u32, stride: u32, lam: f32) -> f32 { + let basis = evalSegP1Basis(lam); + var result: f32 = basis[0] * seg_function_values[offset + 0 * stride]; + result += basis[1] * seg_function_values[offset + 1 * stride]; + return result; +} + +fn evalSegP2Basis(x: f32) -> array { + let y = 1.0 - x; + return array(x * x, 2.0 * x * y, y * y); +} + +fn evalSegP2(offset: u32, stride: u32, lam: f32) -> f32 { + let basis = evalSegP2Basis(lam); + var result: f32 = basis[0] * seg_function_values[offset + 0 * stride]; + result += basis[1] * seg_function_values[offset + 1 * stride]; + result += basis[2] * seg_function_values[offset + 2 * stride]; + return result; +} + +fn evalSegP3Basis(x: f32) -> array { + let y = 1.0 - x; + return array(x * x * x, 3.0 * x * x * y, 3.0 * x * y * y, y * y * y); +} + +fn evalSegP3(offset: u32, stride: u32, lam: f32) -> f32 { + let basis = evalSegP3Basis(lam); + var result: f32 = basis[0] * seg_function_values[offset + 0 * stride]; + result += basis[1] * seg_function_values[offset + 1 * stride]; + result += basis[2] * seg_function_values[offset + 2 * stride]; + result += basis[3] * seg_function_values[offset + 3 * stride]; + return result; +} + +fn evalSegP4Basis(x: f32) -> array { + let y = 1.0 - x; + return array(x * x * x * x, 4.0 * x * x * x * y, 6.0 * x * x * y * y, 4.0 * x * y * y * y, y * y * y * y); +} + +fn evalSegP4(offset: u32, stride: u32, lam: f32) -> f32 { + let basis = evalSegP4Basis(lam); + var result: f32 = basis[0] * seg_function_values[offset + 0 * stride]; + result += basis[1] * seg_function_values[offset + 1 * stride]; + result += basis[2] * seg_function_values[offset + 2 * stride]; + result += basis[3] * seg_function_values[offset + 3 * stride]; + result += basis[4] * seg_function_values[offset + 4 * stride]; + return result; +} + +fn evalSegP5Basis(x: f32) -> array { + let y = 1.0 - x; + return array(x * x * x * x * x, 5.0 * x * x * x * x * y, 10.0 * x * x * x * y * y, 10.0 * x * x * y * y * y, 5.0 * x * y * y * y * y, y * y * y * y * y); +} + +fn evalSegP5(offset: u32, stride: u32, lam: f32) -> f32 { + let basis = evalSegP5Basis(lam); + var result: f32 = basis[0] * seg_function_values[offset + 0 * stride]; + result += basis[1] * seg_function_values[offset + 1 * stride]; + result += basis[2] * seg_function_values[offset + 2 * stride]; + result += basis[3] * seg_function_values[offset + 3 * stride]; + result += basis[4] * seg_function_values[offset + 4 * stride]; + result += basis[5] * seg_function_values[offset + 5 * stride]; + return result; +} + +fn evalSegP6Basis(x: f32) -> array { + let y = 1.0 - x; + return array(x * x * x * x * x * x, 6.0 * x * x * x * x * x * y, 15.0 * x * x * x * x * y * y, 20.0 * x * x * x * y * y * y, 15.0 * x * x * y * y * y * y, 6.0 * x * y * y * y * y * y, y * y * y * y * y * y); +} + +fn evalSegP6(offset: u32, stride: u32, lam: f32) -> f32 { + let basis = evalSegP6Basis(lam); + var result: f32 = basis[0] * seg_function_values[offset + 0 * stride]; + result += basis[1] * seg_function_values[offset + 1 * stride]; + result += basis[2] * seg_function_values[offset + 2 * stride]; + result += basis[3] * seg_function_values[offset + 3 * stride]; + result += basis[4] * seg_function_values[offset + 4 * stride]; + result += basis[5] * seg_function_values[offset + 5 * stride]; + result += basis[6] * seg_function_values[offset + 6 * stride]; + return result; +} + + +fn evalSeg(id: u32, icomp: u32, lam: f32) -> f32 { + let order: u32 = u32(trig_function_values[1]); + let ncomp: u32 = u32(trig_function_values[0]); + let ndof: u32 = order + 1; + + let offset: u32 = ndof * id + VALUES_OFFSET; + let stride: u32 = ncomp; + + if order == 1 { return evalSegP1(offset, stride, lam); } + if order == 2 { return evalSegP2(offset, stride, lam); } + if order == 3 { return evalSegP3(offset, stride, lam); } + if order == 4 { return evalSegP4(offset, stride, lam); } + if order == 5 { return evalSegP5(offset, stride, lam); } + if order == 6 { return evalSegP6(offset, stride, lam); } + + return 0.0; +} fn evalTrigP1Basis(lam: vec2) -> array { let x = lam.x; let y = lam.y; @@ -201,11 +306,11 @@ fn evalTrigP6(offset: u32, stride: u32, lam: vec2) -> f32 { fn evalTrig(id: u32, icomp: u32, lam: vec2) -> f32 { - let order = u32(trig_function_values[1]); - let ncomp = u32(trig_function_values[0]); - let ndof = (order + 1) * (order + 2) / 2; + let order: u32 = u32(trig_function_values[1]); + let ncomp: u32 = u32(trig_function_values[0]); + let ndof: u32 = (order + 1) * (order + 2) / 2; - let offset = ndof * id + VALUES_OFFSET; + let offset: u32 = ndof * id + VALUES_OFFSET; let stride: u32 = ncomp; if order == 1 { return evalTrigP1(offset, stride, lam); } diff --git a/webgpu/main.py b/webgpu/main.py index 66bfe62..6ea1320 100644 --- a/webgpu/main.py +++ b/webgpu/main.py @@ -17,7 +17,7 @@ async def main(): gpu = await init_webgpu(js.document.getElementById("canvas")) - mesh = ngs.Mesh(unit_square.GenerateMesh(maxh=0.1)) + mesh = ngs.Mesh(unit_square.GenerateMesh(maxh=0.5)) order = 6 gfu = ngs.GridFunction(ngs.H1(mesh, order=order)) # gfu.Set(ngs.IfPos(ngs.x-0.8, 1, 0)) diff --git a/webgpu/mesh.py b/webgpu/mesh.py index e29e9a1..a262fd6 100644 --- a/webgpu/mesh.py +++ b/webgpu/mesh.py @@ -83,26 +83,26 @@ def _create_pipelines(self): ) self._create_pipeline_layout() shader_module = self.gpu.device.createShaderModule(to_js({"code": shader_code})) - # edges_pipeline = self.gpu.device.createRenderPipeline( - # to_js( - # { - # "layout": self._pipeline_layout, - # "vertex": { - # "module": shader_module, - # "entryPoint": "mainVertexEdge", - # }, - # "fragment": { - # "module": shader_module, - # "entryPoint": "mainFragmentEdge", - # "targets": [{"format": self.gpu.format}], - # }, - # "primitive": {"topology": "line-list"}, - # "depthStencil": { - # **self.gpu.depth_stencil, - # }, - # } - # ) - # ) + edges_pipeline = self.gpu.device.createRenderPipeline( + to_js( + { + "layout": self._pipeline_layout, + "vertex": { + "module": shader_module, + "entryPoint": "mainVertexEdgeP1", + }, + "fragment": { + "module": shader_module, + "entryPoint": "mainFragmentEdge", + "targets": [{"format": self.gpu.format}], + }, + "primitive": {"topology": "line-list"}, + "depthStencil": { + **self.gpu.depth_stencil, + }, + } + ) + ) trigs_pipeline = self.gpu.device.createRenderPipeline( to_js( @@ -133,14 +133,14 @@ def _create_pipelines(self): ) self.pipelines = { - # "edges": edges_pipeline, + "edges": edges_pipeline, "trigs": trigs_pipeline, } def render(self, encoder): - # encoder.setPipeline(self.pipelines["edges"]) - # encoder.setBindGroup(0, self._bind_group) - # encoder.draw(2, self.n_edges, 0, 0) + encoder.setPipeline(self.pipelines["edges"]) + encoder.setBindGroup(0, self._bind_group) + encoder.draw(2, 3 * self.n_trigs, 0, 0) encoder.setPipeline(self.pipelines["trigs"]) encoder.setBindGroup(0, self._bind_group) @@ -183,6 +183,22 @@ def create_mesh_buffers(device, region, curve_order=1): n_trigs = len(mesh.ngmesh.Elements2D()) + edge_points = points[2:].reshape(-1, 3, 3) + edges = np.zeros((n_trigs, 3, 2, 3), dtype=np.float32) + for i in range(3): + edges[:, i, 0, :] = edge_points[:, i, :] + edges[:, i, 1, :] = edge_points[:, (i + 1) % 3, :] + edge_data = js.Uint8Array.new(edges.flatten().tobytes()) + edge_buffer = device.createBuffer( + to_js( + { + "size": edge_data.length, + "usage": js.GPUBufferUsage.STORAGE | js.GPUBufferUsage.COPY_DST, + } + ) + ) + device.queue.writeBuffer(edge_buffer, 0, edge_data) + trigs = np.zeros( n_trigs, dtype=[ @@ -194,7 +210,7 @@ def create_mesh_buffers(device, region, curve_order=1): trigs["index"] = [1] * n_trigs data = js.Uint8Array.new(trigs.tobytes()) - buffer = device.createBuffer( + trigs_buffer = device.createBuffer( to_js( { "size": data.length, @@ -202,8 +218,8 @@ def create_mesh_buffers(device, region, curve_order=1): } ) ) - device.queue.writeBuffer(buffer, 0, data) - return {"trigs": buffer} + device.queue.writeBuffer(trigs_buffer, 0, data) + return {"trigs": trigs_buffer, "edges": edge_buffer} def create_function_value_buffers(device, cf, region, order): diff --git a/webgpu/shader.wgsl b/webgpu/shader.wgsl index e8cb440..3a66a2b 100644 --- a/webgpu/shader.wgsl +++ b/webgpu/shader.wgsl @@ -1,6 +1,5 @@ -struct Edge { v: vec2 }; +struct EdgeP1 { p: array }; -struct Segment { v: vec2, index: i32 }; struct TrigP1 { p: array, index: i32 }; // 3 vertices with 3 coordinates each, don't use vec3 due to 16 byte alignment struct TrigP2 { p: array, index: i32 }; @@ -21,9 +20,10 @@ const VALUES_OFFSET: u32 = 2; // storing number of components and order of basis @group(0) @binding(1) var colormap : texture_1d; @group(0) @binding(2) var colormap_sampler : sampler; -// @group(0) @binding(4) var edges : array; +@group(0) @binding(4) var edges_p1 : array; @group(0) @binding(5) var trigs_p1 : array; @group(0) @binding(6) var trig_function_values : array; +@group(0) @binding(7) var seg_function_values : array; struct VertexOutput1d { @builtin(position) fragPosition: vec4, @@ -63,9 +63,21 @@ fn getColor(value: f32) -> vec4 { } @vertex -fn mainVertexTrigP1(@builtin(vertex_index) vertexId: u32, @builtin(instance_index) trigId: u32) -> VertexOutput2d { - let v_offset = 3 * (3 * trigId + vertexId) + VALUES_OFFSET; +fn mainVertexEdgeP1(@builtin(vertex_index) vertexId: u32, @builtin(instance_index) edgeId: u32) -> VertexOutput1d { + let edge = edges_p1[edgeId]; + var p: vec3 = vec3(edge.p[3 * vertexId], edge.p[3 * vertexId + 1], edge.p[3 * vertexId + 2]); + + var lam: f32 = 0.0; + if vertexId == 0 { + lam = 1.0; + } + + var position = calcPosition(p); + return VertexOutput1d(position, p, lam, edgeId); +} +@vertex +fn mainVertexTrigP1(@builtin(vertex_index) vertexId: u32, @builtin(instance_index) trigId: u32) -> VertexOutput2d { let trig = trigs_p1[trigId]; var p = vec3(trig.p[3 * vertexId], trig.p[3 * vertexId + 1], trig.p[3 * vertexId + 2]);