diff --git a/src/splat/splat-material.ts b/src/splat/splat-material.ts index 4a806fd..43b41d9 100644 --- a/src/splat/splat-material.ts +++ b/src/splat/splat-material.ts @@ -8,11 +8,24 @@ import { SEMANTIC_POSITION, SEMANTIC_ATTR11, SEMANTIC_ATTR12, - SEMANTIC_ATTR13, - SEMANTIC_ATTR14 + SEMANTIC_ATTR13 } from "playcanvas"; -const sharedShader = ` +const splatVS = ` + attribute vec3 vertex_position; + attribute vec3 splat_center; + attribute vec3 splat_rotation; + + uniform mat4 matrix_model; + uniform mat4 matrix_view; + uniform mat4 matrix_projection; + uniform mat4 matrix_viewProjection; + + uniform vec2 viewport; + + varying vec2 texCoord; + varying vec4 color; + mat3 quatToMat3(vec3 R) { float x = R.x; @@ -35,17 +48,15 @@ const sharedShader = ` ); } - #ifdef WEBGPU - attribute uint vertex_id; - #else - attribute float vertex_id; - #endif - uniform vec4 tex_params; uniform sampler2D splatColor; + uniform highp sampler2D splatScale; #ifdef WEBGPU - ivec2 getTextureCoords() { + + attribute uint vertex_id; + ivec2 dataUV; + void evalDataUV() { // turn vertex_id into int grid coordinates ivec2 textureSize = ivec2(tex_params.xy); @@ -53,11 +64,24 @@ const sharedShader = ` int gridV = int(float(vertex_id) * invTextureSize.x); int gridU = int(vertex_id - gridV * textureSize.x); - return ivec2(gridU, gridV); + dataUV = ivec2(gridU, gridV); + } + + vec4 getColor() { + return texelFetch(splatColor, dataUV, 0); } + vec3 getScale() { + return texelFetch(splatScale, dataUV, 0).xyz; + } + #else - vec2 getTextureCoords() { + + // TODO: use texture2DLodEXT on WebGL + + attribute float vertex_id; + vec2 dataUV; + void evalDataUV() { vec2 textureSize = tex_params.xy; vec2 invTextureSize = tex_params.zw; @@ -66,37 +90,18 @@ const sharedShader = ` float gridU = vertex_id - (gridV * textureSize.x); // convert grid coordinates to uv coordinates with half pixel offset - return vec2(gridU, gridV) * invTextureSize + (0.5 * invTextureSize); + dataUV = vec2(gridU, gridV) * invTextureSize + (0.5 * invTextureSize); } - #endif - - vec4 getColor() { - #ifdef WEBGPU - ivec2 textureUV = getTextureCoords(); - return texelFetch(splatColor, ivec2(textureUV), 0); - #else - vec2 textureUV = getTextureCoords(); - return texture2D(splatColor, textureUV); - #endif - } -`; - -const splatVS = ` - attribute vec2 vertex_position; - attribute vec3 splat_center; - attribute vec3 splat_rotation; - attribute vec3 splat_scale; - uniform mat4 matrix_model; - uniform mat4 matrix_view; - uniform mat4 matrix_projection; - - uniform vec2 viewport; - - varying vec2 texCoord; - varying vec4 color; + vec4 getColor() { + return texture(splatColor, dataUV); + } - ${sharedShader} + vec3 getScale() { + return texture(splatScale, dataUV).xyz; + } + + #endif void computeCov3d(in vec3 rot, in vec3 scale, out vec3 covA, out vec3 covB) { @@ -130,6 +135,8 @@ const splatVS = ` void main(void) { + evalDataUV(); + vec4 splat_cam = matrix_view * matrix_model * vec4(splat_center, 1.0); vec4 splat_proj = matrix_projection * splat_cam; @@ -141,7 +148,8 @@ const splatVS = ` vec3 splat_cova; vec3 splat_covb; - computeCov3d(splat_rotation, splat_scale, splat_cova, splat_covb); + vec3 scale = getScale(); + computeCov3d(splat_rotation, scale, splat_cova, splat_covb); mat3 Vrk = mat3( splat_cova.x, splat_cova.y, splat_cova.z, @@ -177,9 +185,18 @@ const splatVS = ` vec4((vertex_position.x * v1 + vertex_position.y * v2) / viewport * 2.0, 0.0, 0.0) * splat_proj.w; - texCoord = vertex_position * 2.0; + texCoord = vertex_position.xy * 2.0; color = getColor(); + + #ifdef DEBUG_RENDER + + vec3 local = quatToMat3(splat_rotation) * (vertex_position * scale * 2.0) + splat_center; + gl_Position = matrix_viewProjection * matrix_model * vec4(local, 1.0); + + color = getColor(); + + #endif } `; @@ -189,42 +206,19 @@ const splatFS = /* glsl_ */ ` void main(void) { - float A = -dot(texCoord, texCoord); - if (A < -4.0) discard; - float B = exp(A) * color.a; - gl_FragColor = vec4(color.rgb, B); - } -`; + #ifdef DEBUG_RENDER -const splatDebugVS = /* glsl_ */ ` - attribute vec3 vertex_position; - attribute vec3 splat_center; - attribute vec3 splat_rotation; - attribute vec3 splat_scale; - - uniform mat4 matrix_model; - uniform mat4 matrix_viewProjection; + if (color.a < 0.2) discard; + gl_FragColor = color; - varying vec4 color; - - ${sharedShader} - - void main(void) - { - vec3 local = quatToMat3(splat_rotation) * (vertex_position * splat_scale * 2.0) + splat_center; - gl_Position = matrix_viewProjection * matrix_model * vec4(local, 1.0); + #else - color = getColor(); - } -`; + float A = -dot(texCoord, texCoord); + if (A < -4.0) discard; + float B = exp(A) * color.a; + gl_FragColor = vec4(color.rgb, B); -const splatDebugFS = /* glsl_ */ ` - varying vec4 color; - - void main(void) - { - if (color.a < 0.2) discard; - gl_FragColor = color; + #endif } `; @@ -235,15 +229,15 @@ const createSplatMaterial = (device: GraphicsDevice, debugRender = false) => { result.blendType = BLEND_NORMAL; result.depthWrite = false; - const vs = debugRender ? splatDebugVS : splatVS; - const fs = debugRender ? splatDebugFS : splatFS; + const defines = debugRender ? '#define DEBUG_RENDER\n' : ''; + const vs = defines + splatVS; + const fs = defines + splatFS; - result.shader = createShaderFromCode(device, vs, fs, 'splatShader', { + result.shader = createShaderFromCode(device, vs, fs, `splatShader-${debugRender}`, { vertex_position: SEMANTIC_POSITION, splat_center: SEMANTIC_ATTR11, splat_rotation: SEMANTIC_ATTR12, - splat_scale: SEMANTIC_ATTR13, - vertex_id: SEMANTIC_ATTR14 + vertex_id: SEMANTIC_ATTR13 }); result.update(); diff --git a/src/splat/splat.ts b/src/splat/splat.ts index a583515..94e2466 100644 --- a/src/splat/splat.ts +++ b/src/splat/splat.ts @@ -1,6 +1,3 @@ -// internally stores mesh, and material with shader attached and all that -// SplatResource can call this, wrap it in entity / render component and return that - import { Material, MeshInstance, @@ -18,14 +15,16 @@ import { SEMANTIC_ATTR11, SEMANTIC_ATTR12, SEMANTIC_ATTR13, - SEMANTIC_ATTR14, TYPE_FLOAT32, VertexFormat, TYPE_UINT32, BUFFER_DYNAMIC, VertexBuffer, BoundingBox, - Mat4 + Mat4, + PIXELFORMAT_RGBA16F, + PIXELFORMAT_RGB32F, + PIXELFORMAT_RGBA32F } from "playcanvas"; import { SplatData } from "./splat-data"; import { SortWorker } from "./sort-worker"; @@ -35,6 +34,52 @@ import { createSplatMaterial } from "./splat-material"; const debugRender = false; const debugRenderBounds = false; +const floatView = new Float32Array(1); +const int32View = new Int32Array(floatView.buffer); + +const float2Half = (value: number) => { + // based on https://esdiscuss.org/topic/float16array + // This method is faster than the OpenEXR implementation (very often + // used, eg. in Ogre), with the additional benefit of rounding, inspired + // by James Tursa?s half-precision code. + floatView[0] = value; + const x = int32View[0]; + + let bits = (x >> 16) & 0x8000; // Get the sign + let m = (x >> 12) & 0x07ff; // Keep one extra bit for rounding + const e = (x >> 23) & 0xff; // Using int is faster here + + // If zero, or denormal, or exponent underflows too much for a denormal half, return signed zero. + if (e < 103) { + return bits; + } + + // If NaN, return NaN. If Inf or exponent overflow, return Inf. + if (e > 142) { + bits |= 0x7c00; + + // If exponent was 0xff and one mantissa bit was set, it means NaN, + // not Inf, so make sure we set one mantissa bit too. + bits |= ((e === 255) ? 0 : 1) && (x & 0x007fffff); + return bits; + } + + // If exponent underflows but not too much, return a denormal + if (e < 113) { + m |= 0x0800; + + // Extra rounding may overflow and set mantissa to 0 and exponent to 1, which is OK. + bits |= (m >> (114 - e)) + ((m >> (113 - e)) & 1); + return bits; + } + + bits |= ((e - 112) << 10) | (m >> 1); + + // Extra rounding. An overflow will set mantissa to 0 and increment the exponent, which is OK. + bits += m & 1; + return bits; +}; + class Splat { device: GraphicsDevice; material: Material; @@ -42,13 +87,31 @@ class Splat { quadMesh: Mesh; aabb = new BoundingBox(); focalPoint = new Vec3(); + halfFormat: number; + floatFormat: number; constructor(device: GraphicsDevice) { this.device = device; + this.testTextureFormats(); + this.material = createSplatMaterial(device, debugRender); this.createMesh(); } + testTextureFormats() { + const { device } = this; + this.halfFormat = (device.extTextureHalfFloat && device.textureHalfFloatUpdatable) ? PIXELFORMAT_RGBA16F : undefined; + this.floatFormat = device.extTextureFloat ? PIXELFORMAT_RGB32F : undefined; + + if (device.isWebGPU) { + this.floatFormat = PIXELFORMAT_RGBA32F; + } + } + + getTextureFormat(preferHighPrecision: boolean) { + return preferHighPrecision ? (this.floatFormat ?? this.halfFormat) : (this.halfFormat ?? this.floatFormat); + } + createMesh() { if (debugRender) { this.quadMesh = createBox(this.device, { @@ -84,20 +147,90 @@ class Splat { }); } - create(splatData: SplatData, options: any) { - const x = splatData.getProp('x'); - const y = splatData.getProp('y'); - const z = splatData.getProp('z'); + createColorTexture(splatData: SplatData, size: Vec2) { + + const SH_C0 = 0.28209479177387814; const f_dc_0 = splatData.getProp('f_dc_0'); const f_dc_1 = splatData.getProp('f_dc_1'); const f_dc_2 = splatData.getProp('f_dc_2'); - const opacity = splatData.getProp('opacity'); - const scale_0 = splatData.getProp('scale_0'); - const scale_1 = splatData.getProp('scale_1'); - const scale_2 = splatData.getProp('scale_2'); + const texture = this.createTexture('splatColor', PIXELFORMAT_RGBA8, size); + const data = texture.lock(); + + const sigmoid = (v: number) => { + if (v > 0) { + return 1 / (1 + Math.exp(-v)); + } + + const t = Math.exp(v); + return t / (1 + t); + }; + + for (let i = 0; i < splatData.numSplats; ++i) { + + // colors + if (f_dc_0 && f_dc_1 && f_dc_2) { + data[i * 4 + 0] = math.clamp((0.5 + SH_C0 * f_dc_0[i]) * 255, 0, 255); + data[i * 4 + 1] = math.clamp((0.5 + SH_C0 * f_dc_1[i]) * 255, 0, 255); + data[i * 4 + 2] = math.clamp((0.5 + SH_C0 * f_dc_2[i]) * 255, 0, 255); + } + + // opacity + data[i * 4 + 3] = opacity ? math.clamp(sigmoid(opacity[i]) * 255, 0, 255) : 255; + } + + texture.unlock(); + return texture; + } + + createScaleTexture(splatData: SplatData, size: Vec2, format: number) { + + // texture format based vars + let halfFloat = false; + let numComponents = 3; // RGB32 is used + if (format === PIXELFORMAT_RGBA16F) { + halfFloat = true; + numComponents = 4; // RGBA16 is used, RGB16 does not work + } + + if (format === PIXELFORMAT_RGBA32F) { + numComponents = 4; + } + + const scale0 = splatData.getProp('scale_0'); + const scale1 = splatData.getProp('scale_1'); + const scale2 = splatData.getProp('scale_2'); + + const texture = this.createTexture('splatScale', format, size); + const data = texture.lock(); + + for (let i = 0; i < splatData.numSplats; i++) { + + const sx = Math.exp(scale0[i]); + const sy = Math.exp(scale1[i]); + const sz = Math.exp(scale2[i]); + + if (halfFloat) { + data[i * numComponents + 0] = float2Half(sx); + data[i * numComponents + 1] = float2Half(sy); + data[i * numComponents + 2] = float2Half(sz); + } else { + data[i * numComponents + 0] = sx; + data[i * numComponents + 1] = sy; + data[i * numComponents + 2] = sz; + } + } + + texture.unlock(); + return texture; + } + + create(splatData: SplatData, options: any) { + const x = splatData.getProp('x'); + const y = splatData.getProp('y'); + const z = splatData.getProp('z'); const rot_0 = splatData.getProp('rot_0'); const rot_1 = splatData.getProp('rot_1'); @@ -108,13 +241,13 @@ class Splat { return; } - const stride = 10; + const stride = 7; const textureSize = this.evalTextureSize(splatData.numSplats); - const colorTexture = this.createTexture('splatColor', PIXELFORMAT_RGBA8, textureSize); - const colorData = colorTexture.lock(); + const colorTexture = this.createColorTexture(splatData, textureSize); + const scaleTexture = this.createScaleTexture(splatData, textureSize, this.getTextureFormat(true)); - // position.xyz, color, rotation.xyz, scale.xyz + // position.xyz, rotation.xyz const floatData = new Float32Array(splatData.numSplats * stride); const uint32Data = new Uint32Array(floatData.buffer); @@ -129,34 +262,6 @@ class Splat { floatData[i * stride + 1] = y[j]; floatData[i * stride + 2] = z[j]; - // vertex colors - if (f_dc_0 && f_dc_1 && f_dc_2) { - const SH_C0 = 0.28209479177387814; - const r = math.clamp((0.5 + SH_C0 * f_dc_0[j]) * 255, 0, 255); - const g = math.clamp((0.5 + SH_C0 * f_dc_1[j]) * 255, 0, 255); - const b = math.clamp((0.5 + SH_C0 * f_dc_2[j]) * 255, 0, 255); - - colorData[i * 4 + 0] = r; - colorData[i * 4 + 1] = g; - colorData[i * 4 + 2] = b; - } - - // opacity - if (opacity) { - const sigmoid = (v: number) => { - if (v > 0) { - return 1 / (1 + Math.exp(-v)); - } - - const t = Math.exp(v); - return t / (1 + t); - }; - const a = sigmoid(opacity[j]) * 255; - colorData[i * 4 + 3] = a; - } else { - colorData[i * 4 + 3] = 255; - } - quat.set(rot_0[j], rot_1[j], rot_2[j], rot_3[j]).normalize(); // rotation @@ -170,29 +275,23 @@ class Splat { floatData[i * stride + 5] = quat.z; } - // scale - floatData[i * stride + 6] = Math.exp(scale_0[j]); - floatData[i * stride + 7] = Math.exp(scale_1[j]); - floatData[i * stride + 8] = Math.exp(scale_2[j]); - // index if (isWebGPU) { - uint32Data[i * stride + 9] = i; + uint32Data[i * stride + 6] = i; } else { - floatData[i * stride + 9] = i + 0.2; + floatData[i * stride + 6] = i + 0.2; } } - colorTexture.unlock(); this.material.setParameter('splatColor', colorTexture); + this.material.setParameter('splatScale', scaleTexture); this.material.setParameter('tex_params', new Float32Array([textureSize.x, textureSize.y, 1 / textureSize.x, 1 / textureSize.y])); // create instance data const vertexFormat = new VertexFormat(this.device, [ { semantic: SEMANTIC_ATTR11, components: 3, type: TYPE_FLOAT32 }, { semantic: SEMANTIC_ATTR12, components: 3, type: TYPE_FLOAT32 }, - { semantic: SEMANTIC_ATTR13, components: 3, type: TYPE_FLOAT32 }, - { semantic: SEMANTIC_ATTR14, components: 1, type: isWebGPU ? TYPE_UINT32 : TYPE_FLOAT32 } + { semantic: SEMANTIC_ATTR13, components: 1, type: isWebGPU ? TYPE_UINT32 : TYPE_FLOAT32 } ]); const vertexBuffer = new VertexBuffer(this.device, vertexFormat, splatData.numSplats, BUFFER_DYNAMIC, floatData.buffer); @@ -256,7 +355,6 @@ class Splat { // calculate focal point splatData.calcFocalPoint(this.focalPoint); - } }