-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmerge_submodules.py
105 lines (80 loc) · 4.25 KB
/
merge_submodules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from argparse import Namespace
from pathlib import Path
import sys
import torch
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
from mega_nerf.models.mega_nerf import MegaNeRF
from mega_nerf.models.mega_nerf_container import MegaNeRFContainer
from mega_nerf.models.model_utils import get_nerf, get_bg_nerf
from mega_nerf.opts import get_opts_base
def _get_merge_opts() -> Namespace:
    """Parse the command line for the sub-module merge script.

    Starts from the project's shared base parser (``get_opts_base``) and adds the
    four mandatory string options this script needs. Unrecognised command-line
    arguments are tolerated via ``parse_known_args`` and silently dropped.

    Returns:
        The namespace of recognised options.
    """
    parser = get_opts_base()
    for flag in ('--ckpt_prefix', '--centroid_path', '--output', '--ckpt_iteration'):
        parser.add_argument(flag, type=str, required=True)
    known_args, _unknown_args = parser.parse_known_args()
    return known_args
def _find_checkpoint(centroid_path: Path, ckpt_iteration: str) -> Path:
    """Locate ``<version>/models/<ckpt_iteration>.pt`` for one sub-module.

    Directories directly under *centroid_path* whose names are purely numeric are
    treated as versions and tried from the highest number down; the first one that
    contains the requested checkpoint wins. Any other entries are ignored.

    Args:
        centroid_path: Per-sub-module directory, ``<ckpt_prefix><i>``.
        ckpt_iteration: Checkpoint file stem (the ``--ckpt_iteration`` option).

    Returns:
        Path of the checkpoint file.

    Raises:
        FileNotFoundError: if *centroid_path* does not exist or none of its
            version directories contains the requested checkpoint.
    """
    if not centroid_path.exists():
        raise FileNotFoundError('{} not found'.format(centroid_path))

    # Skip stray files / non-numeric names instead of crashing in int().
    version_dirs = sorted(
        (entry for entry in centroid_path.iterdir() if entry.is_dir() and entry.name.isdecimal()),
        key=lambda entry: int(entry.name),
        reverse=True,
    )
    for version_dir in version_dirs:
        checkpoint = version_dir / 'models' / '{}.pt'.format(ckpt_iteration)
        if checkpoint.exists():
            return checkpoint

    # Name the directory that was searched, not whichever candidate happened to be tried last.
    raise FileNotFoundError('Could not find check point {}.pt in {}'.format(ckpt_iteration, centroid_path))


def _load_weights(module: torch.nn.Module, state_dict: dict) -> torch.nn.Module:
    """Overlay *state_dict* on *module*'s current weights, load the result and return *module*.

    Keys absent from *state_dict* keep the values the freshly built module already
    has; apart from that this is a regular (strict) ``load_state_dict`` call.
    """
    merged = module.state_dict()
    merged.update(state_dict)
    module.load_state_dict(merged)
    return module


def _smoke_test(container, hparams: Namespace, device: torch.device, has_bg: bool) -> None:
    """Wrap the container's sub-modules in MegaNeRF and print one dummy forward pass each.

    Args:
        container: The (reloaded) scripted MegaNeRFContainer.
        hparams: Options; reads ``pos_dir_dim``, ``appearance_dim`` and ``boundary_margin``.
        device: Device the test models and inputs are placed on.
        has_bg: Whether ``bg_sub_module_{i}`` attributes should be tested as well.
    """
    submodule_count = len(container.centroids)

    nerf = MegaNeRF([getattr(container, 'sub_module_{}'.format(i)) for i in range(submodule_count)],
                    container.centroids, hparams.boundary_margin, False, container.cluster_2d).to(device)

    # Dummy input width: 3 base columns, +3 if pos_dir_dim > 0, +1 if appearance_dim > 0
    # (presumably position / view direction / appearance index -- confirm against MegaNeRF).
    width = 3
    if hparams.pos_dir_dim > 0:
        width += 3
    if hparams.appearance_dim > 0:
        width += 1

    print('fg test eval: {}'.format(nerf(torch.ones(1, width, device=device))))

    if has_bg:
        bg_nerf = MegaNeRF([getattr(container, 'bg_sub_module_{}'.format(i)) for i in range(submodule_count)],
                           container.centroids, hparams.boundary_margin, True, container.cluster_2d).to(device)
        # The background model is fed 4 extra input columns (their meaning is not visible here).
        print('bg test eval: {}'.format(bg_nerf(torch.ones(1, width + 4, device=device))))


@torch.inference_mode()
def main(hparams: Namespace) -> None:
    """Merge per-centroid sub-module checkpoints into one TorchScript MegaNeRF container.

    For every centroid ``i`` listed in ``hparams.centroid_path``, the checkpoint
    ``<ckpt_prefix><i>/<version>/models/<ckpt_iteration>.pt`` is loaded. The
    highest numeric version that has the file is used. Its ``model_state_dict``
    goes into a foreground NeRF and, when present, its ``bg_model_state_dict``
    goes into a background NeRF. The resulting container is scripted and saved to
    ``hparams.output``, then reloaded and smoke-tested with a single dummy input.

    Raises:
        FileNotFoundError: if a sub-module directory or checkpoint is missing.
        ValueError: if only some of the checkpoints contain a background model.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    ckpt_prefix = Path(hparams.ckpt_prefix)
    centroid_metadata = torch.load(hparams.centroid_path, map_location='cpu')
    centroids = centroid_metadata['centroids']

    sub_modules = []
    bg_sub_modules = []
    for i in range(len(centroids)):
        centroid_path = ckpt_prefix.parent / '{}{}'.format(ckpt_prefix.name, i)
        loaded = torch.load(_find_checkpoint(centroid_path, hparams.ckpt_iteration), map_location='cpu')

        model_state = loaded['model_state_dict']
        # Strip any 'module.' key prefix first so keys like 'embedding_a.weight' can be looked up.
        consume_prefix_in_state_dict_if_present(model_state, prefix='module.')
        # The appearance-embedding count is taken from the trained embedding table itself.
        appearance_count = len(model_state['embedding_a.weight']) if hparams.appearance_dim > 0 else 0
        sub_modules.append(_load_weights(get_nerf(hparams, appearance_count), model_state))

        if 'bg_model_state_dict' in loaded:
            bg_state = loaded['bg_model_state_dict']
            consume_prefix_in_state_dict_if_present(bg_state, prefix='module.')
            bg_sub_modules.append(_load_weights(get_bg_nerf(hparams, appearance_count), bg_state))

    # The smoke test looks up bg_sub_module_{i} for every centroid, so a partial set cannot work;
    # fail before spending time on scripting and saving.
    if bg_sub_modules and len(bg_sub_modules) != len(sub_modules):
        raise ValueError('Only {} of {} checkpoints contain a background model'.format(
            len(bg_sub_modules), len(sub_modules)))

    container = MegaNeRFContainer(sub_modules, bg_sub_modules, centroids,
                                  torch.IntTensor(centroid_metadata['grid_dim']),
                                  centroid_metadata['min_position'],
                                  centroid_metadata['max_position'],
                                  hparams.pos_dir_dim > 0,
                                  hparams.appearance_dim > 0,
                                  centroid_metadata['cluster_2d'])

    # parents=True: the output may sit more than one level below an existing directory.
    Path(hparams.output).parent.mkdir(parents=True, exist_ok=True)
    torch.jit.save(torch.jit.script(container.eval()), hparams.output)

    # Test the artifact that was written to disk rather than the in-memory object.
    container = torch.jit.load(hparams.output, map_location='cpu')
    _smoke_test(container, hparams, device, has_bg=len(bg_sub_modules) > 0)
if __name__ == '__main__':
    # Script entry point: parse CLI options, then merge, save and smoke-test the container.
    cli_hparams = _get_merge_opts()
    main(cli_hparams)