-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathdemo.py
85 lines (64 loc) · 2.56 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""
Demo code for the paper
Gwak et al.,
Weakly supervised 3D Reconstruction with Adversarial Constraint, 3DV 2017
"""
import os
import sys
# Refuse to run under Python 2: abort before the remaining imports are
# attempted.
if sys.version_info.major < 3:
    raise Exception(
        "Python 3 required. Please follow the installation instruction on 'https://github.com/jgwak/McRecon'"
    )
import shutil
import numpy as np
from subprocess import call
from PIL import Image
from models import load_model
from lib.config import cfg, cfg_from_list
from lib.data_augmentation import preprocess_img
from lib.solver import Solver
from lib.voxel import voxel2obj
# Local path of the pretrained weights file: download_model() saves the
# download here and main() loads the network weights from the same path.
DEFAULT_WEIGHTS = 'output/GANMaskNet/default_model/weights.npy'
def cmd_exists(cmd):
    """Return True when ``shutil.which`` can resolve ``cmd``, else False."""
    resolved = shutil.which(cmd)
    if resolved is None:
        return False
    return True
def download_model(fn):
    """Download the pretrained weights to ``fn`` unless the file already exists.

    Args:
        fn: Destination path of the weights file. curl is invoked with
            ``--create-dirs`` so missing parent directories are created.

    Raises:
        RuntimeError: If curl exits with a non-zero status. Any file left at
            ``fn`` by the failed transfer is removed first, so a later run
            retries the download instead of loading a truncated file.
    """
    if os.path.isfile(fn):
        return

    # Download the file if it doesn't exist
    print('Downloading a pretrained model')
    url = 'ftp://cs.stanford.edu/cs/cvgl/GANMaskNet_furniture.npy'
    status = call(['curl', url, '--create-dirs', '-o', fn])
    if status != 0:
        # A failed transfer may leave a partial file at the output path; the
        # isfile() check above would then skip the download on the next run.
        if os.path.isfile(fn):
            os.remove(fn)
        raise RuntimeError(
            'Failed to download pretrained model from %s (curl exit status %d)'
            % (url, status))
def load_demo_images():
    """Load the three demo views ``imgs/00.png`` .. ``imgs/02.png``.

    Paths are relative to the current working directory. Each image goes
    through ``preprocess_img(..., train=False)`` (first element of its result),
    has its axes reordered with ``transpose((2, 0, 1))`` and is cast to
    float32.

    Returns:
        A numpy array built from one single-element list per view --
        presumably shaped (3, 1, C, H, W); confirm against preprocess_img.
    """
    def _load_view(index):
        # One view wrapped in a list, which adds a singleton axis after the
        # view axis once everything is stacked.
        picture = Image.open('imgs/%02d.png' % index)
        picture = preprocess_img(picture, train=False)[0]
        reordered = np.array(picture).transpose((2, 0, 1))
        return [reordered.astype(np.float32)]

    return np.array([_load_view(index) for index in range(3)])
def main():
    """Run the demo end to end.

    Loads the demo images, makes sure the pretrained weights are present,
    runs the GANMaskNet model through a Solver, writes the thresholded voxel
    prediction as an OBJ mesh and opens it in meshlab when that is installed.
    The output path is the first command-line argument, or 'prediction.obj'
    when none is given.
    """
    # Output mesh path: first CLI argument if present, default otherwise.
    if len(sys.argv) > 1:
        obj_path = sys.argv[1]
    else:
        obj_path = 'prediction.obj'

    # Input views.
    images = load_demo_images()

    # Fetch the pretrained weights when they are not on disk yet.
    download_model(DEFAULT_WEIGHTS)

    # Build the default network and load the weights into it.
    network_cls = load_model('GANMaskNet')
    network = network_cls(compute_grad=False)
    network.load(DEFAULT_WEIGHTS)

    # The solver wraps the network's test function.
    tester = Solver(network)
    voxels, _ = tester.test_output(images)

    # Threshold the selected slice of the prediction and write it as a mesh.
    occupancy = voxels[0, :, 1, :, :] > cfg.TEST.VOXEL_THRESH
    voxel2obj(obj_path, occupancy)

    # Use meshlab or other mesh viewers to visualize the prediction.
    # For Ubuntu>=14.04, you can install meshlab using
    # `sudo apt-get install meshlab`
    if not cmd_exists('meshlab'):
        print('Meshlab not found: please use visualization of your choice to view %s' %
              obj_path)
        return
    call(['meshlab', obj_path])
if __name__ == '__main__':
    # Override the configured batch size with 1 before main() builds the
    # network.
    cfg_from_list([
        'CONST.BATCH_SIZE', 1,
    ])
    main()