Optimize backward propagation kernel (41% end-to-end speedup on the example) #53

Open · wants to merge 7 commits into base: main. Changes from all commits.

2 changes: 2 additions & 0 deletions .gitignore
@@ -1,3 +1,5 @@
 build/
 diff_gaussian_rasterization.egg-info/
 dist/
+__pycache__/
+_C.cpython*
9 changes: 5 additions & 4 deletions cuda_rasterizer/auxiliary.h
@@ -16,7 +16,8 @@
 #include "stdio.h"
 
 #define BLOCK_SIZE (BLOCK_X * BLOCK_Y)
-#define NUM_WARPS (BLOCK_SIZE/32)
+#define WARP_SIZE 32
+#define NUM_WARPS (BLOCK_SIZE/WARP_SIZE)
 
 // Spherical harmonics coefficients
 __device__ const float SH_C0 = 0.28209479177387814f;
@@ -99,15 +100,15 @@ __forceinline__ __device__ float3 transformVec4x3Transpose(const float3& p, cons
 __forceinline__ __device__ float dnormvdz(float3 v, float3 dv)
 {
 	float sum2 = v.x * v.x + v.y * v.y + v.z * v.z;
-	float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
+	float invsum32 = rsqrtf(sum2 * sum2 * sum2);
 	float dnormvdz = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32;
 	return dnormvdz;
 }
 
 __forceinline__ __device__ float3 dnormvdv(float3 v, float3 dv)
 {
 	float sum2 = v.x * v.x + v.y * v.y + v.z * v.z;
-	float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
+	float invsum32 = rsqrtf(sum2 * sum2 * sum2);
 
 	float3 dnormvdv;
 	dnormvdv.x = ((+sum2 - v.x * v.x) * dv.x - v.y * v.x * dv.y - v.z * v.x * dv.z) * invsum32;
@@ -119,7 +120,7 @@ __forceinline__ __device__ float3 dnormvdv(float3 v, float3 dv)
 __forceinline__ __device__ float4 dnormvdv(float4 v, float4 dv)
 {
 	float sum2 = v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w;
-	float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
+	float invsum32 = rsqrtf(sum2 * sum2 * sum2);
 
 	float4 vdv = { v.x * dv.x, v.y * dv.y, v.z * dv.z, v.w * dv.w };
 	float vdv_sum = vdv.x + vdv.y + vdv.z + vdv.w;
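
Note on the auxiliary.h change: each dnormvd* helper needs a reciprocal square root, and 1.0f / sqrt(x) compiles to a full square root followed by a division. CUDA's rsqrtf(x) intrinsic computes the approximate reciprocal square root directly, typically as a single special-function (MUFU.RSQ) instruction. A minimal sketch of the pattern, with illustrative helper names that are not part of the PR:

	__device__ float inv_norm_cubed_slow(float3 v)
	{
		float sum2 = v.x * v.x + v.y * v.y + v.z * v.z;
		return 1.0f / sqrt(sum2 * sum2 * sum2); // full sqrt, then a divide
	}

	__device__ float inv_norm_cubed_fast(float3 v)
	{
		float sum2 = v.x * v.x + v.y * v.y + v.z * v.z;
		return rsqrtf(sum2 * sum2 * sum2);      // one special-function op
	}

The intrinsic's small precision loss is acceptable here, since the value only scales gradient terms.
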
256 changes: 188 additions & 68 deletions cuda_rasterizer/backward.cu
@@ -396,6 +396,7 @@ __global__ void preprocessCUDA(
 }
 
 // Backward version of the rendering procedure.
+#define USE_ATOMIC_THRESHOLD 10
 template <uint32_t C>
 __global__ void __launch_bounds__(BLOCK_X * BLOCK_Y)
 renderCUDA(
@@ -408,6 +409,7 @@ renderCUDA(
 	const float* __restrict__ colors,
 	const float* __restrict__ final_Ts,
 	const uint32_t* __restrict__ n_contrib,
+	const uint32_t* __restrict__ tiles_touched,
 	const float* __restrict__ dL_dpixels,
 	float3* __restrict__ dL_dmean2D,
 	float4* __restrict__ dL_dconic2D,
@@ -428,10 +430,10 @@
 
 	const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE);
 
-	bool done = !inside;
 	int toDo = range.y - range.x;
 
 	__shared__ int collected_id[BLOCK_SIZE];
+	__shared__ bool collected_use_atomic[BLOCK_SIZE];
 	__shared__ float2 collected_xy[BLOCK_SIZE];
 	__shared__ float4 collected_conic_opacity[BLOCK_SIZE];
 	__shared__ float collected_colors[C * BLOCK_SIZE];
@@ -470,6 +472,9 @@
 		if (range.x + progress < range.y)
 		{
 			const int coll_id = point_list[range.y - progress - 1];
+			const int cur_tiles_touched = tiles_touched[coll_id];
+			bool cur_use_atomic = cur_tiles_touched <= USE_ATOMIC_THRESHOLD;
+			collected_use_atomic[block.thread_rank()] = cur_use_atomic;
 			collected_id[block.thread_rank()] = coll_id;
 			collected_xy[block.thread_rank()] = points_xy_image[coll_id];
 			collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id];
@@ -478,80 +483,193 @@
 		}
 		block.sync();
 
+		static constexpr int REDUCTION_BATCH_SIZE = 16;
+		int cur_reduction_batch_idx = 0;
+		__shared__ int batch_j[REDUCTION_BATCH_SIZE];
+		__shared__ float batch_dL_dcolors[REDUCTION_BATCH_SIZE][NUM_WARPS][C];
+		__shared__ float2 batch_dL_dmean2D[REDUCTION_BATCH_SIZE][NUM_WARPS];
+		__shared__ float4 batch_dL_dconic2D_dopacity[REDUCTION_BATCH_SIZE][NUM_WARPS];
 
 		// Iterate over Gaussians
-		for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++)
+		for (int j = 0; j < min(BLOCK_SIZE, toDo); j++)
 		{
 			// Keep track of current Gaussian ID. Skip, if this one
 			// is behind the last contributor for this pixel.
+			float cur_dL_dcolors[C] = {0};
+			float2 cur_dL_dmean2D = {0, 0};
+			float4 cur_dL_dconic2D_dopacity = {0, 0, 0, 0};
+			const bool use_atomic = collected_use_atomic[j];
+
 			contributor--;
-			if (contributor >= last_contributor)
-				continue;
 
-			// Compute blending values, as before.
-			const float2 xy = collected_xy[j];
-			const float2 d = { xy.x - pixf.x, xy.y - pixf.y };
-			const float4 con_o = collected_conic_opacity[j];
-			const float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) - con_o.y * d.x * d.y;
-			if (power > 0.0f)
-				continue;
-
-			const float G = exp(power);
-			const float alpha = min(0.99f, con_o.w * G);
-			if (alpha < 1.0f / 255.0f)
-				continue;
-
-			T = T / (1.f - alpha);
-			const float dchannel_dcolor = alpha * T;
-
-			// Propagate gradients to per-Gaussian colors and keep
-			// gradients w.r.t. alpha (blending factor for a Gaussian/pixel
-			// pair).
-			float dL_dalpha = 0.0f;
-			const int global_id = collected_id[j];
-			for (int ch = 0; ch < C; ch++)
-			{
-				const float c = collected_colors[ch * BLOCK_SIZE + j];
-				// Update last color (to be used in the next iteration)
-				accum_rec[ch] = last_alpha * last_color[ch] + (1.f - last_alpha) * accum_rec[ch];
-				last_color[ch] = c;
-
-				const float dL_dchannel = dL_dpixel[ch];
-				dL_dalpha += (c - accum_rec[ch]) * dL_dchannel;
-				// Update the gradients w.r.t. color of the Gaussian.
-				// Atomic, since this pixel is just one of potentially
-				// many that were affected by this Gaussian.
-				atomicAdd(&(dL_dcolors[global_id * C + ch]), dchannel_dcolor * dL_dchannel);
+			if (inside && contributor < last_contributor) {
+				// Compute blending values, as before.
+				const float2 xy = collected_xy[j];
+				const float2 d = { xy.x - pixf.x, xy.y - pixf.y };
+				const float4 con_o = collected_conic_opacity[j];
+				const float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) - con_o.y * d.x * d.y;
+
+				const float G = exp(power);
+				const float alpha = min(0.99f, con_o.w * G);
+				if (power <= 0.0f && alpha >= 1.0f / 255.0f) {
+					T = T / (1.f - alpha);
+					const float dchannel_dcolor = alpha * T;
+
+					// Propagate gradients to per-Gaussian colors and keep
+					// gradients w.r.t. alpha (blending factor for a Gaussian/pixel
+					// pair).
+					float dL_dalpha = 0.0f;
+					const int global_id = collected_id[j];
+					#pragma unroll
+					for (int ch = 0; ch < C; ch++)
+					{
+						const float c = collected_colors[ch * BLOCK_SIZE + j];
+						// Update last color (to be used in the next iteration)
+						accum_rec[ch] = last_alpha * last_color[ch] + (1.f - last_alpha) * accum_rec[ch];
+						last_color[ch] = c;
+
+						const float dL_dchannel = dL_dpixel[ch];
+						dL_dalpha += (c - accum_rec[ch]) * dL_dchannel;
+						// Update the gradients w.r.t. color of the Gaussian.
+						// Atomic, since this pixel is just one of potentially
+						// many that were affected by this Gaussian.
+						if (use_atomic) {
+							atomicAdd(&dL_dcolors[global_id * C + ch], dchannel_dcolor * dL_dchannel);
+						} else {
+							cur_dL_dcolors[ch] = dchannel_dcolor * dL_dchannel;
+						}
+					}
+					dL_dalpha *= T;
+					// Update last alpha (to be used in the next iteration)
+					last_alpha = alpha;
+
+					// Account for fact that alpha also influences how much of
+					// the background color is added if nothing left to blend
+					float bg_dot_dpixel = 0;
+					#pragma unroll
+					for (int i = 0; i < C; i++)
+						bg_dot_dpixel += bg_color[i] * dL_dpixel[i];
+					dL_dalpha += (-T_final / (1.f - alpha)) * bg_dot_dpixel;
+
+					// Helpful reusable temporary variables
+					const float dL_dG = con_o.w * dL_dalpha;
+					const float gdx = G * d.x;
+					const float gdy = G * d.y;
+					const float dG_ddelx = -gdx * con_o.x - gdy * con_o.y;
+					const float dG_ddely = -gdy * con_o.z - gdx * con_o.y;
+
+					if (use_atomic) {
+						// Update gradients w.r.t. 2D mean position of the Gaussian
+						atomicAdd(&dL_dmean2D[global_id].x, dL_dG * dG_ddelx * ddelx_dx);
+						atomicAdd(&dL_dmean2D[global_id].y, dL_dG * dG_ddely * ddely_dy);
+						// Update gradients w.r.t. 2D covariance (2x2 matrix, symmetric)
+						atomicAdd(&dL_dconic2D[global_id].x, -0.5f * gdx * d.x * dL_dG);
+						atomicAdd(&dL_dconic2D[global_id].y, -0.5f * gdx * d.y * dL_dG);
+						atomicAdd(&dL_dconic2D[global_id].w, -0.5f * gdy * d.y * dL_dG);
+						// Update gradients w.r.t. opacity of the Gaussian
+						atomicAdd(&dL_dopacity[global_id], G * dL_dalpha);
+					} else {
+						cur_dL_dmean2D = {
+							dL_dG * dG_ddelx * ddelx_dx,
+							dL_dG * dG_ddely * ddely_dy
+						};
+						cur_dL_dconic2D_dopacity = {
+							-0.5f * gdx * d.x * dL_dG,
+							-0.5f * gdx * d.y * dL_dG,
+							G * dL_dalpha,
+							-0.5f * gdy * d.y * dL_dG
+						};
+					}
+				}
+			}
-			dL_dalpha *= T;
-			// Update last alpha (to be used in the next iteration)
-			last_alpha = alpha;
-
-			// Account for fact that alpha also influences how much of
-			// the background color is added if nothing left to blend
-			float bg_dot_dpixel = 0;
-			for (int i = 0; i < C; i++)
-				bg_dot_dpixel += bg_color[i] * dL_dpixel[i];
-			dL_dalpha += (-T_final / (1.f - alpha)) * bg_dot_dpixel;
-
-
-			// Helpful reusable temporary variables
-			const float dL_dG = con_o.w * dL_dalpha;
-			const float gdx = G * d.x;
-			const float gdy = G * d.y;
-			const float dG_ddelx = -gdx * con_o.x - gdy * con_o.y;
-			const float dG_ddely = -gdy * con_o.z - gdx * con_o.y;
-
-			// Update gradients w.r.t. 2D mean position of the Gaussian
-			atomicAdd(&dL_dmean2D[global_id].x, dL_dG * dG_ddelx * ddelx_dx);
-			atomicAdd(&dL_dmean2D[global_id].y, dL_dG * dG_ddely * ddely_dy);
-
-			// Update gradients w.r.t. 2D covariance (2x2 matrix, symmetric)
-			atomicAdd(&dL_dconic2D[global_id].x, -0.5f * gdx * d.x * dL_dG);
-			atomicAdd(&dL_dconic2D[global_id].y, -0.5f * gdx * d.y * dL_dG);
-			atomicAdd(&dL_dconic2D[global_id].w, -0.5f * gdy * d.y * dL_dG);
+			if (!use_atomic) {
+				// Perform warp-level reduction
+				#pragma unroll
+				for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
+					#pragma unroll
+					for (int ch = 0; ch < C; ch++)
+						cur_dL_dcolors[ch] += __shfl_down_sync(0xFFFFFFFF, cur_dL_dcolors[ch], offset);
+					cur_dL_dmean2D.x += __shfl_down_sync(0xFFFFFFFF, cur_dL_dmean2D.x, offset);
+					cur_dL_dmean2D.y += __shfl_down_sync(0xFFFFFFFF, cur_dL_dmean2D.y, offset);
+					cur_dL_dconic2D_dopacity.x += __shfl_down_sync(0xFFFFFFFF, cur_dL_dconic2D_dopacity.x, offset);
+					cur_dL_dconic2D_dopacity.y += __shfl_down_sync(0xFFFFFFFF, cur_dL_dconic2D_dopacity.y, offset);
+					cur_dL_dconic2D_dopacity.z += __shfl_down_sync(0xFFFFFFFF, cur_dL_dconic2D_dopacity.z, offset);
+					cur_dL_dconic2D_dopacity.w += __shfl_down_sync(0xFFFFFFFF, cur_dL_dconic2D_dopacity.w, offset);
+				}
+
+				// Store the results in shared memory
+				if (block.thread_rank() % WARP_SIZE == 0)
+				{
+					int warp_id = block.thread_rank() / WARP_SIZE;
+					batch_j[cur_reduction_batch_idx] = j;
+					#pragma unroll
+					for (int ch = 0; ch < C; ch++)
+						batch_dL_dcolors[cur_reduction_batch_idx][warp_id][ch] = cur_dL_dcolors[ch];
+					batch_dL_dmean2D[cur_reduction_batch_idx][warp_id] = cur_dL_dmean2D;
+					batch_dL_dconic2D_dopacity[cur_reduction_batch_idx][warp_id] = cur_dL_dconic2D_dopacity;
+				}
+				cur_reduction_batch_idx += 1;
+			}
+
-			// Update gradients w.r.t. opacity of the Gaussian
-			atomicAdd(&(dL_dopacity[global_id]), G * dL_dalpha);
+			// If this is the last Gaussian in the batch, perform block-level
+			// reduction and store the results in global memory.
+			if (cur_reduction_batch_idx == REDUCTION_BATCH_SIZE || (j == min(BLOCK_SIZE, toDo) - 1 && cur_reduction_batch_idx != 0))
+			{
+				// Make sure we can perform this reduction with one warp
+				static_assert(NUM_WARPS <= WARP_SIZE);
+				// Make sure the number of warps is a power of 2
+				static_assert((NUM_WARPS & (NUM_WARPS - 1)) == 0);
+
+				// Wait for all warps to finish storing
+				block.sync();
+
+				for (int batch_id = block.thread_rank() / WARP_SIZE; batch_id < cur_reduction_batch_idx; batch_id += NUM_WARPS) {
+					int lane_id = block.thread_rank() % WARP_SIZE;
+
+					// Load the per-warp partial sums staged for this batch entry
+					#pragma unroll
+					for (int ch = 0; ch < C; ch++)
+						cur_dL_dcolors[ch] = lane_id < NUM_WARPS ? batch_dL_dcolors[batch_id][lane_id][ch] : 0;
+					cur_dL_dmean2D = lane_id < NUM_WARPS ? batch_dL_dmean2D[batch_id][lane_id] : float2{0, 0};
+					cur_dL_dconic2D_dopacity = lane_id < NUM_WARPS ? batch_dL_dconic2D_dopacity[batch_id][lane_id] : float4{0, 0, 0, 0};
+
+					// Perform warp-level reduction over the NUM_WARPS partials
+					#pragma unroll
+					for (int offset = NUM_WARPS / 2; offset > 0; offset /= 2) {
+						#pragma unroll
+						for (int ch = 0; ch < C; ch++)
+							cur_dL_dcolors[ch] += __shfl_down_sync(0xFFFFFFFF, cur_dL_dcolors[ch], offset);
+						cur_dL_dmean2D.x += __shfl_down_sync(0xFFFFFFFF, cur_dL_dmean2D.x, offset);
+						cur_dL_dmean2D.y += __shfl_down_sync(0xFFFFFFFF, cur_dL_dmean2D.y, offset);
+						cur_dL_dconic2D_dopacity.x += __shfl_down_sync(0xFFFFFFFF, cur_dL_dconic2D_dopacity.x, offset);
+						cur_dL_dconic2D_dopacity.y += __shfl_down_sync(0xFFFFFFFF, cur_dL_dconic2D_dopacity.y, offset);
+						cur_dL_dconic2D_dopacity.z += __shfl_down_sync(0xFFFFFFFF, cur_dL_dconic2D_dopacity.z, offset);
+						cur_dL_dconic2D_dopacity.w += __shfl_down_sync(0xFFFFFFFF, cur_dL_dconic2D_dopacity.w, offset);
+					}
+
+					// Store the results in global memory
+					if (lane_id == 0)
+					{
+						const int global_id = collected_id[batch_j[batch_id]];
+						#pragma unroll
+						for (int ch = 0; ch < C; ch++)
+							atomicAdd(&dL_dcolors[global_id * C + ch], cur_dL_dcolors[ch]);
+						atomicAdd(&dL_dmean2D[global_id].x, cur_dL_dmean2D.x);
+						atomicAdd(&dL_dmean2D[global_id].y, cur_dL_dmean2D.y);
+						atomicAdd(&dL_dconic2D[global_id].x, cur_dL_dconic2D_dopacity.x);
+						atomicAdd(&dL_dconic2D[global_id].y, cur_dL_dconic2D_dopacity.y);
+						atomicAdd(&dL_dconic2D[global_id].w, cur_dL_dconic2D_dopacity.w);
+						atomicAdd(&dL_dopacity[global_id], cur_dL_dconic2D_dopacity.z);
+					}
+				}
+
+				// Wait for all warps to finish reducing
+				if (j != min(BLOCK_SIZE, toDo) - 1)
+					block.sync();
+
+				cur_reduction_batch_idx = 0;
+			}
 		}
 	}
 }
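
How the new renderCUDA accumulation works: tiles_touched records how many tiles each Gaussian overlaps, which is a proxy for how many pixels will write gradients for it concurrently. A Gaussian touching at most USE_ATOMIC_THRESHOLD (10) tiles keeps the original path, one global atomicAdd per pixel, since contention on its gradient slots is low. A heavily contended Gaussian is instead accumulated in registers, summed across each 32-lane warp with __shfl_down_sync, staged one partial per warp in shared memory, and flushed every REDUCTION_BATCH_SIZE (16) Gaussians by a block-level pass that reduces the NUM_WARPS partials and issues a single atomicAdd per gradient. This is also why the !done flag and the continue statements had to go: __shfl_down_sync with a full mask and block.sync() require every thread to arrive, so skipped work is expressed as predicated if blocks instead. A minimal, self-contained sketch of the warp-level building block used above (the helper name is illustrative, not from the PR):

	__device__ float warp_reduce_sum(float v)
	{
		// Tree reduction: after log2(32) = 5 shuffle steps, lane 0
		// holds the sum of v over all 32 lanes of the warp.
		#pragma unroll
		for (int offset = 32 / 2; offset > 0; offset /= 2)
			v += __shfl_down_sync(0xFFFFFFFF, v, offset);
		return v; // only lane 0 holds the complete sum
	}

Lanes with nothing to contribute still execute the shuffles and simply add zeros, which is exactly what the zero-initialized cur_dL_* registers provide.
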
@@ -632,6 +750,7 @@ void BACKWARD::render(
 	const float* colors,
 	const float* final_Ts,
 	const uint32_t* n_contrib,
+	const uint32_t* tiles_touched,
 	const float* dL_dpixels,
 	float3* dL_dmean2D,
 	float4* dL_dconic2D,
@@ -648,10 +767,11 @@
 		colors,
 		final_Ts,
 		n_contrib,
+		tiles_touched,
 		dL_dpixels,
 		dL_dmean2D,
 		dL_dconic2D,
 		dL_dopacity,
 		dL_dcolors
 		);
 }
1 change: 1 addition & 0 deletions cuda_rasterizer/backward.h
@@ -31,6 +31,7 @@ namespace BACKWARD
 		const float* colors,
 		const float* final_Ts,
 		const uint32_t* n_contrib,
+		const uint32_t* tiles_touched,
 		const float* dL_dpixels,
 		float3* dL_dmean2D,
 		float4* dL_dconic2D,
1 change: 1 addition & 0 deletions cuda_rasterizer/rasterizer_impl.cu
@@ -399,6 +399,7 @@ void CudaRasterizer::Rasterizer::backward(
 		color_ptr,
 		imgState.accum_alpha,
 		imgState.n_contrib,
+		geomState.tiles_touched,
 		dL_dpix,
 		(float3*)dL_dmean2D,
 		(float4*)dL_dconic,