Skip to content

Commit

Permalink
Add np.random.choice fallback for many Gaussians exceeding torch.mult…
Browse files Browse the repository at this point in the history
…inomial limits (#338)

* use np.random.choice for >2**24 multinomial

* format
  • Loading branch information
soskek authored Aug 21, 2024
1 parent 435305e commit af10217
Showing 1 changed file with 37 additions and 2 deletions.
39 changes: 37 additions & 2 deletions gsplat/strategy/ops.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
from typing import Callable, Dict, List, Union

import torch
Expand All @@ -9,6 +10,40 @@
from gsplat.utils import normalized_quat_to_rotmat


@torch.no_grad()
def _multinomial_sample(weights: Tensor, n: int, replacement: bool = True) -> Tensor:
"""Sample from a distribution using torch.multinomial or numpy.random.choice.
This function adaptively chooses between `torch.multinomial` and `numpy.random.choice`
based on the number of elements in `weights`. If the number of elements exceeds
the torch.multinomial limit (2^24), it falls back to using `numpy.random.choice`.
Args:
weights (Tensor): A 1D tensor of weights for each element.
n (int): The number of samples to draw.
replacement (bool): Whether to sample with replacement. Default is True.
Returns:
Tensor: A 1D tensor of sampled indices.
"""
num_elements = weights.size(0)

if num_elements <= 2**24:
# Use torch.multinomial for elements within the limit
return torch.multinomial(weights, n, replacement=replacement)
else:
# Fallback to numpy.random.choice for larger element spaces
weights = weights / weights.sum()
weights_np = weights.detach().cpu().numpy()
sampled_idxs_np = np.random.choice(
num_elements, size=n, p=weights_np, replace=replacement
)
sampled_idxs = torch.from_numpy(sampled_idxs_np)

# Return the sampled indices on the original device
return sampled_idxs.to(weights.device)


@torch.no_grad()
def _update_param_with_optimizer(
param_fn: Callable[[str, Tensor], Tensor],
Expand Down Expand Up @@ -226,7 +261,7 @@ def relocate(
# Sample for new GSs
eps = torch.finfo(torch.float32).eps
probs = opacities[alive_indices].flatten() # ensure its shape is [N,]
sampled_idxs = torch.multinomial(probs, n, replacement=True)
sampled_idxs = _multinomial_sample(probs, n, replacement=True)
sampled_idxs = alive_indices[sampled_idxs]
new_opacities, new_scales = compute_relocation(
opacities=opacities[sampled_idxs],
Expand Down Expand Up @@ -269,7 +304,7 @@ def sample_add(

eps = torch.finfo(torch.float32).eps
probs = opacities.flatten()
sampled_idxs = torch.multinomial(probs, n, replacement=True)
sampled_idxs = _multinomial_sample(probs, n, replacement=True)
new_opacities, new_scales = compute_relocation(
opacities=opacities[sampled_idxs],
scales=torch.exp(params["scales"])[sampled_idxs],
Expand Down

0 comments on commit af10217

Please sign in to comment.