-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmain.py
29 lines (23 loc) · 1.08 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
from utils.config import get_dataset, get_args
from utils.post_process import post_process
from graph.construction import mask_graph_construction
from graph.iterative_clustering import iterative_clustering
from tqdm import tqdm
import os
def main(args):
dataset = get_dataset(args)
scene_points = dataset.get_scene_points()
frame_list = dataset.get_frame_list(args.step)
# if os.path.exists(os.path.join(dataset.object_dict_dir, args.config, f'object_dict.npy')):
# return
with torch.no_grad():
nodes, observer_num_thresholds, mask_point_clouds, point_frame_matrix = mask_graph_construction(args, scene_points, frame_list, dataset)
object_list = iterative_clustering(nodes, observer_num_thresholds, args.view_consensus_threshold, args.debug)
post_process(dataset, object_list, mask_point_clouds, scene_points, point_frame_matrix, frame_list, args)
if __name__ == '__main__':
args = get_args()
seq_name_list = args.seq_name_list.split('+')
for seq_name in tqdm(seq_name_list):
args.seq_name = seq_name
main(args)