Skip to content

Commit

Permalink
test pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Ruilong Li committed Aug 6, 2024
1 parent bc048e8 commit 7bd65db
Show file tree
Hide file tree
Showing 9 changed files with 223 additions and 49 deletions.
2 changes: 1 addition & 1 deletion examples/simple_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def depth_to_normal(
height,
render_mode="RGB+ED",
sh_degree=sh_degree,
accurate_depth=False,
# accurate_depth=False,
depth_clips=torch.stack(
[
torch.full((C, height, width), 0.30, device=device),
Expand Down
5 changes: 1 addition & 4 deletions gsplat/cuda/_torch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,16 +410,12 @@ def accumulate(
A = c[:, -1] # [M,] conics22
B = torch.einsum("...i,...i->...", c[:, [2, 4, 5]], (mu - o)) # [M]
D = B / A # [M]
print("c", c[:, -1].min(), c[:, -1].max())
# print("D", D)
# print("mu", mu)

# clip the gaussian integral range
if depth_clips is not None:
beta = torch.sqrt(A / 2)
z0, z1 = torch.unbind(depth_clips[camera_ids, pixel_ids_y, pixel_ids_x], dim=-1)
ratio = 0.5 * (torch.erf(beta * (z1 - D)) - torch.erf(beta * (z0 - D)))
print("ratio", ratio.min(), ratio.max())
# alphas = alphas * ratio
alphas = 1.0 - (1.0 - alphas) ** ratio

Expand Down Expand Up @@ -515,6 +511,7 @@ def _rasterize_to_pixels(
tile_size,
isect_offsets,
flatten_ids,
depth_clips=depth_clips,
) # [M], [M]
if len(gs_ids) == 0:
break
Expand Down
10 changes: 10 additions & 0 deletions gsplat/cuda/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ def rasterize_to_pixels(
backgrounds: Optional[Tensor] = None, # [C, channels]
packed: bool = False,
absgrad: bool = False,
depth_clips: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
"""Rasterizes Gaussians to pixels.
Expand Down Expand Up @@ -491,6 +492,7 @@ def rasterize_to_pixels(
tile_size,
isect_offsets.contiguous(),
flatten_ids.contiguous(),
depth_clips.contiguous() if depth_clips is not None else None,
absgrad,
)

Expand All @@ -512,6 +514,7 @@ def rasterize_to_indices_in_range(
tile_size: int,
isect_offsets: Tensor, # [C, tile_height, tile_width]
flatten_ids: Tensor, # [n_isects]
depth_clips: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
"""Rasterizes a batch of Gaussians to images but only returns the indices.
Expand Down Expand Up @@ -568,6 +571,7 @@ def rasterize_to_indices_in_range(
tile_size,
isect_offsets.contiguous(),
flatten_ids.contiguous(),
depth_clips.contiguous() if depth_clips is not None else None,
)
out_pixel_ids = out_indices % (image_width * image_height)
out_camera_ids = out_indices // (image_width * image_height)
Expand Down Expand Up @@ -826,6 +830,7 @@ def forward(
tile_size: int,
isect_offsets: Tensor, # [C, tile_height, tile_width]
flatten_ids: Tensor, # [n_isects]
depth_clips: Optional[Tensor],
absgrad: bool,
) -> Tuple[Tensor, Tensor]:
render_colors, render_alphas, last_ids = _make_lazy_cuda_func(
Expand All @@ -841,6 +846,7 @@ def forward(
tile_size,
isect_offsets,
flatten_ids,
depth_clips,
)

ctx.save_for_backward(
Expand All @@ -851,6 +857,7 @@ def forward(
backgrounds,
isect_offsets,
flatten_ids,
depth_clips,
render_alphas,
last_ids,
)
Expand All @@ -877,6 +884,7 @@ def backward(
backgrounds,
isect_offsets,
flatten_ids,
depth_clips,
render_alphas,
last_ids,
) = ctx.saved_tensors
Expand All @@ -902,6 +910,7 @@ def backward(
tile_size,
isect_offsets,
flatten_ids,
depth_clips,
render_alphas,
last_ids,
v_render_colors.contiguous(),
Expand Down Expand Up @@ -931,6 +940,7 @@ def backward(
None,
None,
None,
None,
)


Expand Down
25 changes: 15 additions & 10 deletions gsplat/cuda/csrc/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <tuple>

#define N_THREADS 256
#define INV_SQRT_PI 0.564189583547756286948079451560772585844050629329f

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
Expand Down Expand Up @@ -122,8 +123,9 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> rasterize_to_pixels_fwd_
// image size
const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size,
// intersections
const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
const torch::Tensor &flatten_ids // [n_isects]
const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
const torch::Tensor &flatten_ids, // [n_isects]
const at::optional<torch::Tensor> &depth_clips // [C, H, W, 2]
);

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
Expand All @@ -137,8 +139,9 @@ rasterize_to_pixels_bwd_tensor(
// image size
const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size,
// intersections
const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
const torch::Tensor &flatten_ids, // [n_isects]
const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
const torch::Tensor &flatten_ids, // [n_isects]
const at::optional<torch::Tensor> &depth_clips, // [C, H, W, 2]
// forward outputs
const torch::Tensor &render_alphas, // [C, image_height, image_width, 1]
const torch::Tensor &last_ids, // [C, image_height, image_width]
Expand All @@ -149,17 +152,19 @@ rasterize_to_pixels_bwd_tensor(
bool absgrad);

std::tuple<torch::Tensor, torch::Tensor> rasterize_to_indices_in_range_tensor(
const uint32_t range_start, const uint32_t range_end, // iteration steps
const uint32_t range_start,
const uint32_t range_end, // iteration steps
const torch::Tensor transmittances, // [C, image_height, image_width]
// Gaussian parameters
const torch::Tensor &means2d, // [C, N, 2]
const torch::Tensor &conics, // [C, N, 3]
const torch::Tensor &opacities, // [N]
const torch::Tensor &means2d, // [C, N, 3]
const torch::Tensor &conics, // [C, N, 6]
const torch::Tensor &opacities, // [C, N]
// image size
const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size,
// intersections
const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
const torch::Tensor &flatten_ids // [n_isects]
const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
const torch::Tensor &flatten_ids, // [n_isects]
const at::optional<torch::Tensor> &depth_clips // [C, H, W, 2]
);

torch::Tensor compute_sh_fwd_tensor(const uint32_t degrees_to_use,
Expand Down
43 changes: 40 additions & 3 deletions gsplat/cuda/csrc/rasterize_to_indices_in_range.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ __global__ void rasterize_to_indices_in_range_kernel(
const uint32_t tile_width, const uint32_t tile_height,
const int32_t *__restrict__ tile_offsets, // [C, tile_height, tile_width]
const int32_t *__restrict__ flatten_ids, // [n_isects]
const T *__restrict__ transmittances, // [C, image_height, image_width]
const T *__restrict__ depth_clips, // [C, image_height, image_width, 2] optional
const T *__restrict__ transmittances, // [C, image_height, image_width]
const int32_t *__restrict__ chunk_starts, // [C, image_height, image_width]
int32_t *__restrict__ chunk_cnts, // [C, image_height, image_width]
int64_t *__restrict__ gaussian_ids, // [n_elems]
Expand Down Expand Up @@ -50,6 +51,17 @@ __global__ void rasterize_to_indices_in_range_kernel(
bool inside = (i < image_height && j < image_width);
bool done = !inside;

// clip the range of the gaussian integral
T z0, z1;
if (depth_clips != nullptr) {
depth_clips += camera_id * image_height * image_width * 2;
z0 = depth_clips[pix_id * 2 + 0]; // min
z1 = depth_clips[pix_id * 2 + 1]; // max
if (z0 >= z1 || z0 < 0 || z1 < 0) {
done = true;
}
}

bool first_pass = chunk_starts == nullptr;
int32_t base;
if (!first_pass && inside) {
Expand Down Expand Up @@ -145,6 +157,23 @@ __global__ void rasterize_to_indices_in_range_kernel(
continue;
}

// clip the gaussian integral range
if (depth_clips != nullptr) {
T A = conic22;
T B = conic02 * (mean2d.x - px) + conic12 * (mean2d.y - py) +
conic22 * mean2d.z;
T D = B / A; // the accurate depth

T beta = sqrtf(A * 0.5f);
T ratio = 0.5f * (erff(beta * (z1 - D)) - erff(beta * (z0 - D)));

// alpha *= ratio;
alpha = 1.0f - powf(1.0f - alpha, ratio); // this produces smaller diff
if (alpha < 1.f / 255.f) {
continue;
}
}

next_trans = trans * (1.0f - alpha);
if (next_trans <= 1e-4) { // this pixel is done: exclusive
done = true;
Expand Down Expand Up @@ -184,15 +213,19 @@ std::tuple<torch::Tensor, torch::Tensor> rasterize_to_indices_in_range_tensor(
// image size
const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size,
// intersections
const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
const torch::Tensor &flatten_ids // [n_isects]
const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
const torch::Tensor &flatten_ids, // [n_isects]
const at::optional<torch::Tensor> &depth_clips // [C, H, W, 2]
) {
DEVICE_GUARD(means2d);
CHECK_INPUT(means2d);
CHECK_INPUT(conics);
CHECK_INPUT(opacities);
CHECK_INPUT(tile_offsets);
CHECK_INPUT(flatten_ids);
if (depth_clips.has_value()) {
CHECK_INPUT(depth_clips.value());
}

uint32_t C = means2d.size(0); // number of cameras
uint32_t N = means2d.size(1); // number of gaussians
Expand Down Expand Up @@ -229,6 +262,8 @@ std::tuple<torch::Tensor, torch::Tensor> rasterize_to_indices_in_range_tensor(
conics.data_ptr<float>(), opacities.data_ptr<float>(), image_width,
image_height, tile_size, tile_width, tile_height,
tile_offsets.data_ptr<int32_t>(), flatten_ids.data_ptr<int32_t>(),
depth_clips.has_value() ? depth_clips.value().data_ptr<float>()
: nullptr,
transmittances.data_ptr<float>(), nullptr,
chunk_cnts.data_ptr<int32_t>(), nullptr, nullptr);

Expand All @@ -252,6 +287,8 @@ std::tuple<torch::Tensor, torch::Tensor> rasterize_to_indices_in_range_tensor(
conics.data_ptr<float>(), opacities.data_ptr<float>(), image_width,
image_height, tile_size, tile_width, tile_height,
tile_offsets.data_ptr<int32_t>(), flatten_ids.data_ptr<int32_t>(),
depth_clips.has_value() ? depth_clips.value().data_ptr<float>()
: nullptr,
transmittances.data_ptr<float>(), chunk_starts.data_ptr<int32_t>(),
nullptr, gaussian_ids.data_ptr<int64_t>(),
pixel_ids.data_ptr<int64_t>());
Expand Down
Loading

0 comments on commit 7bd65db

Please sign in to comment.