forked from xhp-hust-2018-2011/SS-DCNet
-
Notifications
You must be signed in to change notification settings - Fork 2
/
all_main.py
46 lines (42 loc) · 1.6 KB
/
all_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
import argparse
from main_process import main
from IOtools import get_config_str
# Short CLI dataset key -> (dataset directory name under ``data/``, max_list).
# Keeping name and max_list together avoids the index-aligned parallel lists
# the script previously used, which could silently drift out of sync.
DATASET_SETTINGS = {
    'SHA': ('SH_partA', [22]),
    'SHB': ('SH_partB', [7]),
    'QNRF': ('UCF-QNRF_ECCV18', [8]),
}


def build_parser():
    """Return the command-line parser for this script.

    ``--dataset`` is restricted to the keys of ``DATASET_SETTINGS`` so an
    unknown value produces an argparse usage error (exit status 2) instead of
    an unhandled ``KeyError`` traceback.
    """
    parser = argparse.ArgumentParser(description='Dataset_setting')
    parser.add_argument('--dataset', default='SHB',
                        choices=list(DATASET_SETTINGS),
                        help='choose dataset, SHA,SHB,QNRF')
    return parser


def build_options(dataset_key):
    """Return the option dict handed to ``main`` for one dataset.

    Args:
        dataset_key: short dataset name, one of the keys of
            ``DATASET_SETTINGS`` ('SHA', 'SHB' or 'QNRF').

    Returns:
        A new dict of settings. ``max_list`` is a fresh copy, so callers may
        mutate it without affecting ``DATASET_SETTINGS``.

    Raises:
        KeyError: if ``dataset_key`` is not a known dataset.
    """
    dataset_name, max_list = DATASET_SETTINGS[dataset_key]
    opt = dict()
    opt['dataset'] = dataset_name
    opt['max_list'] = list(max_list)
    # step1: create root path for dataset
    opt['root_dir'] = os.path.join(r'data', opt['dataset'])
    opt['num_workers'] = 0
    opt['IF_savemem_train'] = False
    opt['IF_savemem_test'] = False
    # -- test setting
    opt['test_batch_size'] = 1
    # -- network setting (presumably patch size / patch stride — confirm in main_process)
    opt['psize'], opt['pstride'] = 64, 64
    opt['div_times'] = 2
    # -- parse class to count setting
    parse_method_dict = {0: 'maxp'}
    opt['parse_method'] = parse_method_dict[0]
    # step2: set the max number and partition method
    opt['max_num'] = opt['max_list'][0]
    opt['partition'] = 'two_linear'
    opt['step'] = 0.5
    # NOTE(review): hard-coded off; presumably selects CPU in main — confirm.
    opt['cuda'] = False
    # model directory is keyed by the short CLI name, not the dataset dir name
    opt['model_path'] = os.path.join('model', dataset_key)
    return opt


if __name__ == '__main__':
    args = build_parser().parse_args()
    main(build_options(args.dataset))