-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy pathdemo_nicp.py
56 lines (49 loc) · 2.42 KB
/
demo_nicp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# Copyright 2021 by Haozhe Wu, Tsinghua University, Department of Computer Science and Technology.
# All rights reserved.
# This file is part of the pytorch-nicp,
# and is released under the "MIT License Agreement". Please see the LICENSE
# file that should have been included as part of this package.
import torch
import io3d
import render
import numpy as np
import json
from utils import normalize_mesh, normalize_pcl
from landmark import get_mesh_landmark
from bfm_model import load_bfm_model
from nicp import non_rigid_icp_mesh2pcl, non_rigid_icp_mesh2mesh
# Demo 1: register the BFM template mesh to a target mesh.
# Landmarks for the target mesh are estimated automatically (get_mesh_landmark with a dummy renderer).
# Preconditions:
#   - the face must point toward the z axis;
#   - the mesh or point cloud must be normalized with normalize_mesh / normalize_pcl
#     before it is fed into the NICP process.
device = torch.device('cuda:0')
meshes = io3d.load_obj_as_mesh('./test_data/pjanic.obj', device = device)

# Normalization, landmark estimation and template loading are preprocessing only,
# so they run without building an autograd graph.
with torch.no_grad():
    norm_meshes, norm_param = normalize_mesh(meshes)
    # NOTE(review): the meaning of [1, 0, 0] is defined by render.create_dummy_render — confirm there.
    dummy_render = render.create_dummy_render([1, 0, 0], device = device)
    target_lm_index, lm_mask = get_mesh_landmark(norm_meshes, dummy_render)
    # Load the template on the same device as the target instead of hard-coding 'cuda:0' a second time.
    bfm_meshes, bfm_lm_index = load_bfm_model(device)
    # Keep a landmark only if its mask is True along all of dim 0
    # (presumably the batch dimension — confirm in landmark.get_mesh_landmark).
    lm_mask = torch.all(lm_mask, dim = 0)
    # The same mask selects from both index tensors, so template and target landmarks stay paired.
    bfm_lm_index_m = bfm_lm_index[:, lm_mask]
    target_lm_index_m = target_lm_index[:, lm_mask]

# Load the fine-grained NICP settings; `with` guarantees the file handle is closed.
with open('config/fine_grain.json', encoding = 'utf-8') as config_file:
    fine_config = json.load(config_file)
registered_mesh = non_rigid_icp_mesh2mesh(bfm_meshes, norm_meshes, bfm_lm_index_m, target_lm_index_m, fine_config)
io3d.save_meshes_as_objs(['final.obj'], registered_mesh, save_textures = False)
# Demo 2: register the BFM template mesh to a target point cloud.
# Target landmarks are read from a text file containing one integer index per line
# (presumably point indices into test2.ply — confirm). Negative values mark landmarks
# that are excluded from the fit. Because the same mask selects from the BFM landmark
# indices, the file must list exactly one entry per BFM landmark, in the same order.
# Re-created here so this demo does not depend on the mesh demo above.
device = torch.device('cuda:0')
pcls = io3d.load_ply_as_pointcloud('./test_data/test2.ply', device = device)
norm_pcls, norm_param = normalize_pcl(pcls)

# `with` closes the landmark file; blank lines (e.g. extra trailing newlines) are
# skipped instead of crashing on int('').
with open('./test_data/test2_lm.txt', encoding = 'utf-8') as pcl_lm_file:
    lm_list = [int(line.strip()) for line in pcl_lm_file if line.strip()]

# Explicit torch.long so indexing does not depend on NumPy's platform-default integer type.
target_lm_index = torch.tensor(lm_list, dtype = torch.long, device = device)
lm_mask = (target_lm_index >= 0)
target_lm_index = target_lm_index.unsqueeze(0)
# Load the template on the same device as the target instead of hard-coding 'cuda:0' a second time.
bfm_meshes, bfm_lm_index = load_bfm_model(device)
bfm_lm_index_m = bfm_lm_index[:, lm_mask]
target_lm_index_m = target_lm_index[:, lm_mask]

# Load the coarse-grained NICP settings; `with` guarantees the file handle is closed.
with open('config/coarse_grain.json', encoding = 'utf-8') as config_file:
    coarse_config = json.load(config_file)
registered_mesh = non_rigid_icp_mesh2pcl(bfm_meshes, norm_pcls, bfm_lm_index_m, target_lm_index_m, coarse_config)
io3d.save_meshes_as_objs(['test_data/final2.obj'], registered_mesh, save_textures = False)