Skip to content

Commit

Permalink
Splat scale is stored in texture instead of VB
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin Valigursky committed Oct 20, 2023
1 parent 6517e31 commit 3ca582c
Show file tree
Hide file tree
Showing 2 changed files with 227 additions and 135 deletions.
152 changes: 73 additions & 79 deletions src/splat/splat-material.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,24 @@ import {
SEMANTIC_POSITION,
SEMANTIC_ATTR11,
SEMANTIC_ATTR12,
SEMANTIC_ATTR13,
SEMANTIC_ATTR14
SEMANTIC_ATTR13
} from "playcanvas";

const sharedShader = `
const splatVS = `
attribute vec3 vertex_position;
attribute vec3 splat_center;
attribute vec3 splat_rotation;
uniform mat4 matrix_model;
uniform mat4 matrix_view;
uniform mat4 matrix_projection;
uniform mat4 matrix_viewProjection;
uniform vec2 viewport;
varying vec2 texCoord;
varying vec4 color;
mat3 quatToMat3(vec3 R)
{
float x = R.x;
Expand All @@ -35,29 +48,40 @@ const sharedShader = `
);
}
#ifdef WEBGPU
attribute uint vertex_id;
#else
attribute float vertex_id;
#endif
uniform vec4 tex_params;
uniform sampler2D splatColor;
uniform highp sampler2D splatScale;
#ifdef WEBGPU
ivec2 getTextureCoords() {
attribute uint vertex_id;
ivec2 dataUV;
void evalDataUV() {
// turn vertex_id into int grid coordinates
ivec2 textureSize = ivec2(tex_params.xy);
vec2 invTextureSize = tex_params.zw;
int gridV = int(float(vertex_id) * invTextureSize.x);
int gridU = int(vertex_id - gridV * textureSize.x);
return ivec2(gridU, gridV);
dataUV = ivec2(gridU, gridV);
}
vec4 getColor() {
return texelFetch(splatColor, dataUV, 0);
}
vec3 getScale() {
return texelFetch(splatScale, dataUV, 0).xyz;
}
#else
vec2 getTextureCoords() {
// TODO: use texture2DLodEXT on WebGL
attribute float vertex_id;
vec2 dataUV;
void evalDataUV() {
vec2 textureSize = tex_params.xy;
vec2 invTextureSize = tex_params.zw;
Expand All @@ -66,37 +90,18 @@ const sharedShader = `
float gridU = vertex_id - (gridV * textureSize.x);
// convert grid coordinates to uv coordinates with half pixel offset
return vec2(gridU, gridV) * invTextureSize + (0.5 * invTextureSize);
dataUV = vec2(gridU, gridV) * invTextureSize + (0.5 * invTextureSize);
}
#endif
vec4 getColor() {
#ifdef WEBGPU
ivec2 textureUV = getTextureCoords();
return texelFetch(splatColor, ivec2(textureUV), 0);
#else
vec2 textureUV = getTextureCoords();
return texture2D(splatColor, textureUV);
#endif
}
`;

const splatVS = `
attribute vec2 vertex_position;
attribute vec3 splat_center;
attribute vec3 splat_rotation;
attribute vec3 splat_scale;
uniform mat4 matrix_model;
uniform mat4 matrix_view;
uniform mat4 matrix_projection;
uniform vec2 viewport;
varying vec2 texCoord;
varying vec4 color;
vec4 getColor() {
return texture(splatColor, dataUV);
}
${sharedShader}
vec3 getScale() {
return texture(splatScale, dataUV).xyz;
}
#endif
void computeCov3d(in vec3 rot, in vec3 scale, out vec3 covA, out vec3 covB)
{
Expand Down Expand Up @@ -130,6 +135,8 @@ const splatVS = `
void main(void)
{
evalDataUV();
vec4 splat_cam = matrix_view * matrix_model * vec4(splat_center, 1.0);
vec4 splat_proj = matrix_projection * splat_cam;
Expand All @@ -141,7 +148,8 @@ const splatVS = `
vec3 splat_cova;
vec3 splat_covb;
computeCov3d(splat_rotation, splat_scale, splat_cova, splat_covb);
vec3 scale = getScale();
computeCov3d(splat_rotation, scale, splat_cova, splat_covb);
mat3 Vrk = mat3(
splat_cova.x, splat_cova.y, splat_cova.z,
Expand Down Expand Up @@ -177,9 +185,18 @@ const splatVS = `
vec4((vertex_position.x * v1 + vertex_position.y * v2) / viewport * 2.0,
0.0, 0.0) * splat_proj.w;
texCoord = vertex_position * 2.0;
texCoord = vertex_position.xy * 2.0;
color = getColor();
#ifdef DEBUG_RENDER
vec3 local = quatToMat3(splat_rotation) * (vertex_position * scale * 2.0) + splat_center;
gl_Position = matrix_viewProjection * matrix_model * vec4(local, 1.0);
color = getColor();
#endif
}
`;

Expand All @@ -189,42 +206,19 @@ const splatFS = /* glsl_ */ `
void main(void)
{
float A = -dot(texCoord, texCoord);
if (A < -4.0) discard;
float B = exp(A) * color.a;
gl_FragColor = vec4(color.rgb, B);
}
`;
#ifdef DEBUG_RENDER
const splatDebugVS = /* glsl_ */ `
attribute vec3 vertex_position;
attribute vec3 splat_center;
attribute vec3 splat_rotation;
attribute vec3 splat_scale;
uniform mat4 matrix_model;
uniform mat4 matrix_viewProjection;
if (color.a < 0.2) discard;
gl_FragColor = color;
varying vec4 color;
${sharedShader}
void main(void)
{
vec3 local = quatToMat3(splat_rotation) * (vertex_position * splat_scale * 2.0) + splat_center;
gl_Position = matrix_viewProjection * matrix_model * vec4(local, 1.0);
#else
color = getColor();
}
`;
float A = -dot(texCoord, texCoord);
if (A < -4.0) discard;
float B = exp(A) * color.a;
gl_FragColor = vec4(color.rgb, B);
const splatDebugFS = /* glsl_ */ `
varying vec4 color;
void main(void)
{
if (color.a < 0.2) discard;
gl_FragColor = color;
#endif
}
`;

Expand All @@ -235,15 +229,15 @@ const createSplatMaterial = (device: GraphicsDevice, debugRender = false) => {
result.blendType = BLEND_NORMAL;
result.depthWrite = false;

const vs = debugRender ? splatDebugVS : splatVS;
const fs = debugRender ? splatDebugFS : splatFS;
const defines = debugRender ? '#define DEBUG_RENDER\n' : '';
const vs = defines + splatVS;
const fs = defines + splatFS;

result.shader = createShaderFromCode(device, vs, fs, 'splatShader', {
result.shader = createShaderFromCode(device, vs, fs, `splatShader-${debugRender}`, {
vertex_position: SEMANTIC_POSITION,
splat_center: SEMANTIC_ATTR11,
splat_rotation: SEMANTIC_ATTR12,
splat_scale: SEMANTIC_ATTR13,
vertex_id: SEMANTIC_ATTR14
vertex_id: SEMANTIC_ATTR13
});

result.update();
Expand Down
Loading

0 comments on commit 3ca582c

Please sign in to comment.