Added meshing of density field using marching cubes #11

Open · wants to merge 3 commits into base: main
4 changes: 3 additions & 1 deletion configs/lego.txt
@@ -1,6 +1,6 @@
expname = blender_paper_lego
basedir = ./logs
datadir = ./data/nerf_synthetic/lego
datadir = /home/ubuntu/aman/datasets/lego
dataset_type = blender

no_batching = True
@@ -17,3 +17,5 @@ precrop_iters = 500
precrop_frac = 0.5

half_res = True

mesh_res = 256
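
The new mesh_res key sets the resolution of the marching-cubes grid consumed by run_nerf.py (see the --mesh_res option below). As a rough sanity check on what that resolution costs, the sketch below counts the density queries involved; the chunk size of 32768 is only an assumed placeholder for args.chunk, not a value taken from this PR.

# Back-of-the-envelope cost of mesh_res (illustrative only; chunk size is assumed)
mesh_res = 256
chunk = 32768                          # placeholder for args.chunk
num_queries = mesh_res ** 3            # 16,777,216 density evaluations for the full grid
num_chunks = -(-num_queries // chunk)  # ceil division: 512 network forward passes
print(num_queries, num_chunks)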
122 changes: 122 additions & 0 deletions mesh_utils.py
@@ -0,0 +1,122 @@
import argparse
import numpy as np
import torch
import plyfile
import skimage.measure
from tqdm import tqdm
import yaml
import os.path as osp
import skimage
import time

def convert_sigma_samples_to_ply(
input_3d_sigma_array: np.ndarray,
voxel_grid_origin,
volume_size,
ply_filename_out,
level=5.0,
offset=None,
scale=None,):
"""
Convert density samples to .ply
    :param input_3d_sigma_array: a float array of shape (n, n, n) holding the sampled densities
    :voxel_grid_origin: a list of three floats: the minimum (bottom-left-down) corner of the voxel grid
    :volume_size: a list of three floats: the total extent of the grid along each axis
:ply_filename_out: string, path of the filename to save to
This function adapted from: https://github.com/RobotLocomotion/spartan
"""
start_time = time.time()

    # marching_cubes expects per-voxel spacing, so convert from the total volume size first
    voxel_spacing = [s / (d - 1) for s, d in zip(volume_size, input_3d_sigma_array.shape)]
    verts, faces, normals, values = skimage.measure.marching_cubes(
        input_3d_sigma_array, level=level, spacing=tuple(voxel_spacing)
    )

    # transform from voxel-grid coordinates to world coordinates by adding the grid origin
mesh_points = np.zeros_like(verts)
mesh_points[:, 0] = voxel_grid_origin[0] + verts[:, 0]
mesh_points[:, 1] = voxel_grid_origin[1] + verts[:, 1]
mesh_points[:, 2] = voxel_grid_origin[2] + verts[:, 2]

# apply additional offset and scale
if scale is not None:
mesh_points = mesh_points / scale
if offset is not None:
mesh_points = mesh_points - offset

# try writing to the ply file

# mesh_points = np.matmul(mesh_points, np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]]))


num_verts = verts.shape[0]
num_faces = faces.shape[0]

verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])

for i in range(0, num_verts):
verts_tuple[i] = tuple(mesh_points[i, :])

faces_building = []
for i in range(0, num_faces):
faces_building.append(((faces[i, :].tolist(),)))
faces_tuple = np.array(faces_building, dtype=[("vertex_indices", "i4", (3,))])

el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex")
el_faces = plyfile.PlyElement.describe(faces_tuple, "face")

ply_data = plyfile.PlyData([el_verts, el_faces])
print("saving mesh to %s" % str(ply_filename_out))
ply_data.write(ply_filename_out)

print(
"converting to ply format and writing to file took {} s".format(
time.time() - start_time
)
)

def generate_and_write_mesh(bounding_box, num_pts, levels, chunk, device, ply_root, **render_kwargs):
"""
Generate density grid for marching cubes
:bounding_box: bounding box for meshing
:num_pts: Number of grid elements on each axis
:levels: list of levels to write meshes for
:ply_root: string, path of the folder to save meshes to
"""

near = render_kwargs['near']
bb_min = (*(bounding_box[0] + near).cpu().numpy(),)
bb_max = (*(bounding_box[1] - near).cpu().numpy(),)

x_vals = torch.tensor(np.linspace(bb_min[0], bb_max[0], num_pts))
y_vals = torch.tensor(np.linspace(bb_min[1], bb_max[1], num_pts))
z_vals = torch.tensor(np.linspace(bb_min[2], bb_max[2], num_pts))

xs, ys, zs = torch.meshgrid(x_vals, y_vals, z_vals, indexing = 'ij')
coords = torch.stack((xs, ys, zs), dim = -1)

coords = coords.view(1, -1, 3).type(torch.FloatTensor).to(device)
dummy_viewdirs = torch.tensor([0, 0, 1]).view(-1, 3).type(torch.FloatTensor).to(device)

nerf_model = render_kwargs['network_fine']
radiance_field = render_kwargs['network_query_fn']

chunk_outs = []

    for k in tqdm(range(0, coords.shape[1], chunk), desc="Retrieving densities at grid points"):
        # query the network for this chunk of grid points; the last output channel is the density
        chunk_out = radiance_field(coords[:, k: k + chunk, :], dummy_viewdirs, nerf_model)
        chunk_outs.append(chunk_out.detach().cpu().numpy()[:, :, -1])

input_sigma_arr = np.concatenate(chunk_outs, axis = -1).reshape(num_pts, num_pts, num_pts)

for level in levels:
try:
sizes = (abs(bounding_box[1] - bounding_box[0]).cpu()).tolist()
convert_sigma_samples_to_ply(input_sigma_arr, list(bb_min), sizes, osp.join(ply_root, f"test_mesh_{level}.ply"), level = level)
except ValueError:
print(f"Density field does not seem to have an isosurface at level {level} yet")
30 changes: 30 additions & 0 deletions run_nerf.py
@@ -19,6 +19,7 @@
from optimizer import MultiOptimizer
from radam import RAdam
from loss import sigma_sparsity_loss, total_variation_loss
from mesh_utils import generate_and_write_mesh

from load_llff import load_llff_data
from load_deepvoxels import load_dv_data
@@ -540,6 +541,12 @@ def config_parser():
parser.add_argument("--render_factor", type=int, default=0,
help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')

# mesh options
parser.add_argument("--mesh_only", action='store_true',
help='do not optimize, reload weights and generate mesh')
parser.add_argument("--mesh_res", type=int, default=256,
help='resolution of grid for marching cubes')

# training options
parser.add_argument("--precrop_iters", type=int, default=0,
help='number of steps to train on central crops')
@@ -585,6 +592,8 @@ def config_parser():
help='frequency of testset saving')
parser.add_argument("--i_video", type=int, default=5000,
help='frequency of render_poses video saving')
parser.add_argument("--i_mesh", type=int, default=1000,
help='frequency of mesh saving')

parser.add_argument("--finest_res", type=int, default=512,
                        help='finest resolution for hashed embedding')
@@ -757,6 +766,16 @@ def train():

return

if args.mesh_only:
levels = [0, 5, 10, 15, 20]
print(f"Generating mesh at levels {levels}")
num_pts = args.mesh_res
root_path = os.path.join(basedir, expname, 'test')
os.makedirs(root_path, exist_ok=True)
generate_and_write_mesh(bounding_box, num_pts, levels, args.chunk, device, root_path, **render_kwargs_train)
print('Done, saving mesh at ', root_path)
return

# Prepare raybatch tensor if batching random rays
N_rand = args.N_rand
use_batching = not args.no_batching
@@ -930,6 +949,17 @@ def train():
# render_kwargs_test['c2w_staticcam'] = None
# imageio.mimwrite(moviebase + 'rgb_still.mp4', to8b(rgbs_still), fps=30, quality=8)

if i%args.i_mesh==0 and i > 0:
levels = [5, 10, 20]
print(f"Generating mesh at levels {levels}")
num_pts = args.mesh_res
root_path = os.path.join(basedir, expname, 'train')
os.makedirs(root_path, exist_ok=True)

with torch.no_grad():
generate_and_write_mesh(bounding_box, num_pts, levels, args.chunk, device, root_path, **render_kwargs_train)
print('Done, saving mesh at ', root_path)

if i%args.i_testset==0 and i > 0:
testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
os.makedirs(testsavedir, exist_ok=True)
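
Usage note: with these changes, meshes can be extracted either on demand or periodically during training. The invocations below are illustrative; only --mesh_only, --mesh_res, and --i_mesh are introduced by this PR, and the --config flag plus config path are assumed from the standard nerf-pytorch setup this repository is based on.

# extract meshes from a trained model without further optimization
python run_nerf.py --config configs/lego.txt --mesh_only --mesh_res 256

# during training, meshes are also written every --i_mesh iterations (default 1000)
python run_nerf.py --config configs/lego.txt --i_mesh 1000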