Skip to content

Commit

Permalink
Optimized splat sorting to make it 3x faster (#270)
Browse files Browse the repository at this point in the history
Co-authored-by: Martin Valigursky <[email protected]>
  • Loading branch information
mvaligursky and Martin Valigursky authored Nov 1, 2023
1 parent cebf613 commit f900c13
Showing 1 changed file with 35 additions and 77 deletions.
112 changes: 35 additions & 77 deletions src/splat/sort-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@ function SortWorker() {
const epsilon = 0.0001;

// number of bits used to store the distance in integer array. Smaller number gives it a smaller
// precision but fewer radix sort passes. Could even be dynamic for less precise sorting.
// precision but faster sorting. Could be dynamic for less precise sorting.
// 16 bits seems plenty for large scenes (train); 10 bits is enough for sled.
const compareBits = 16;

// a larger base means far fewer radix sort passes, but each pass is slightly slower. Big win
// to use 512 vs 10. Need to find the sweet spot for this.
const radixBase = 512;
// number of buckets for count sorting to represent each unique distance using compareBits bits
const bucketCount = (2 ** compareBits) + 1;

let data: Float32Array;
let centers: Float32Array;
Expand All @@ -31,71 +30,10 @@ function SortWorker() {
const lastCameraPosition = { x: 0, y: 0, z: 0 };
const lastCameraDirection = { x: 0, y: 0, z: 0 };

let orderBuffer: BigUint64Array;
let orderBuffer32: Uint32Array;
let orderBufferTmp: BigUint64Array;
let distances: Uint32Array;
let indices: Uint32Array;
let target: Float32Array;

/**
 * Runs one stable counting-sort pass over the first `n` elements of `arr`, ordering them by a
 * single base-`radixBase` digit of their sort key.
 *
 * The sort key of element `i` is read from `arr32[i * 2 + 1]`; `arr32` is expected to be a
 * Uint32Array view over the buffer of `arr` (the visible caller constructs it that way).
 *
 * @param arr - Elements to sort. Overwritten with the result of this pass unless `outputArray`
 * is supplied.
 * @param arr32 - 32-bit view used to read the sort key of each element.
 * @param temp - Scratch buffer that receives the elements in sorted order for this pass.
 * @param n - Number of elements to process.
 * @param exp - Place value of the digit being sorted on (a power of `radixBase`).
 * @param intIndices - When true, the even 32-bit word of each sorted element is written to
 * `outputArray` unchanged; when false, 0.2 is added to it before writing.
 * @param outputArray - When supplied, receives the sorted even words directly and `arr` is left
 * untouched, saving the copy back from `temp`. Pass null for intermediate passes.
 */
const countSort = (arr: BigUint64Array, arr32: Uint32Array, temp: BigUint64Array, n: number, exp: number, intIndices: boolean, outputArray: Uint32Array | Float32Array | null) => {
    // digit of element i's sort key at the place value given by exp
    const digitOf = (i: number) => Math.floor(arr32[i * 2 + 1] / exp) % radixBase;

    // typed arrays are zero-filled on construction, so no explicit clearing is needed
    const count = new Uint32Array(radixBase);

    // Store count of occurrences of each digit
    for (let i = 0; i < n; i++) {
        count[digitOf(i)]++;
    }

    // Prefix sum: count[d] becomes the index one past the last slot reserved for digit d
    for (let d = 1; d < radixBase; d++) {
        count[d] += count[d - 1];
    }

    // Build the sorted pass in temp. Walking backwards keeps elements with equal digits in their
    // original relative order, which the multi-pass radix sort relies on.
    for (let i = n - 1; i >= 0; i--) {
        const d = digitOf(i);
        temp[count[d] - 1] = arr[i];
        count[d]--;
    }

    if (outputArray) {

        // outputting directly to the final array avoids the copy back from the temp array
        const temp32 = new Uint32Array(temp.buffer);
        if (intIndices) {

            for (let i = 0; i < n; i++) {
                outputArray[i] = temp32[i * 2];
            }

        } else {

            for (let i = 0; i < n; i++) {
                outputArray[i] = temp32[i * 2] + 0.2;
            }
        }

    } else {

        // Copy temp back to arr, so that arr is sorted according to the current digit
        for (let i = 0; i < n; i++) {
            arr[i] = temp[i];
        }
    }
};

/**
 * Sorts the first `n` elements of `arr` with a least-significant-digit radix sort, running one
 * `countSort` pass per base-`radixBase` digit needed to cover keys up to `2 ** compareBits`.
 *
 * @param arr - Elements to sort.
 * @param arr32 - 32-bit view passed through to `countSort` for reading sort keys.
 * @param arrTmp - Scratch buffer passed through to `countSort`.
 * @param n - Number of elements to sort.
 * @param intIndices - Passed through to `countSort`; selects how values are written to
 * `finalArray`.
 * @param finalArray - Destination handed to `countSort` on the last pass only; earlier passes
 * receive null and write their result back into `arr`.
 */
const radixSort = (arr: BigUint64Array, arr32: Uint32Array, arrTmp: BigUint64Array, n: number, intIndices: boolean, finalArray: any) => {

    // largest key value, which determines how many digits have to be processed
    const maxKey = 2 ** compareBits;

    // placeValue is radixBase^i for the i-th digit; it is passed instead of the digit number
    let placeValue = 1;
    while (Math.floor(maxKey / placeValue) > 0) {

        // the pass is the last one when the next place value would exceed the largest key
        const nextPlaceValue = placeValue * radixBase;
        const isLastPass = Math.floor(maxKey / nextPlaceValue) === 0;

        countSort(arr, arr32, arrTmp, n, placeValue, intIndices, isLastPass ? finalArray : null);

        placeValue = nextPlaceValue;
    }
};
let countBuffer: Uint32Array;

const update = () => {
if (!centers || !data || !cameraPosition || !cameraDirection) return;
Expand Down Expand Up @@ -127,10 +65,9 @@ function SortWorker() {
const numVertices = centers.length / 3;

// create distance buffer
if (orderBuffer?.length !== numVertices) {
orderBuffer = new BigUint64Array(numVertices);
orderBuffer32 = new Uint32Array(orderBuffer.buffer);
orderBufferTmp = new BigUint64Array(numVertices);
if (distances?.length !== numVertices) {
distances = new Uint32Array(numVertices);
indices = new Uint32Array(numVertices);
target = new Float32Array(numVertices);
}

Expand All @@ -150,6 +87,12 @@ function SortWorker() {
}
}

if (!countBuffer)
countBuffer = new Uint32Array(bucketCount);

for (let i = 0; i < bucketCount; i++)
countBuffer[i] = 0;

// generate per vertex distance to camera
const range = maxDist - minDist;
const divider = 1 / range * (2 ** compareBits);
Expand All @@ -158,13 +101,28 @@ function SortWorker() {
const d = (centers[istride + 0] - px) * dx +
(centers[istride + 1] - py) * dy +
(centers[istride + 2] - pz) * dz;
orderBuffer32[i * 2 + 0] = i;
orderBuffer32[i * 2 + 1] = Math.floor((d - minDist) * divider);
const sortKey = Math.floor((d - minDist) * divider);

distances[i] = sortKey;
indices[i] = i;

// count occurrences of each distance
countBuffer[sortKey]++;
}

// sort indices by distance only, so use distance in orderBuffer32 as sorting key
const finalArray = intIndices ? new Uint32Array(target.buffer) : target;
radixSort(orderBuffer, orderBuffer32, orderBufferTmp, numVertices, intIndices, finalArray);
// Change countBuffer[i] so that it contains the end position (exclusive) of distance bucket i in outputArray
for (let i = 1; i < bucketCount; i++)
countBuffer[i] += countBuffer[i - 1];

// Build the output array
const outputArray = intIndices ? new Uint32Array(target.buffer) : target;
const offset = intIndices ? 0 : 0.2;
for (let i = numVertices - 1; i >= 0; i--) {
const distance = distances[i];
const index = indices[i];
outputArray[countBuffer[distance] - 1] = index + offset;
countBuffer[distance]--;
}

// swap
const tmp = data;
Expand Down

0 comments on commit f900c13

Please sign in to comment.