train.py
'''The main file for training a model.
Usage:
    python train.py --trainer <trainer_module> [--config <path/to/config.yaml>] [--debug] [KEY VALUE ...]
'''
import os
import sys
import torch
import logging
from config.default import get_cfg_defaults
import argparse
import importlib
from utils.utils import mkdirs
import random
import numpy as np

# Enable the cuDNN autotuner by default; init_seed() disables it again when a fixed seed is used.
torch.backends.cudnn.benchmark = True


def default_argument_parser():
    """
    Create the argument parser for training.
    """
    parser = argparse.ArgumentParser("Args parser for training")
    parser.add_argument("--config", default="", metavar="FILE", help="path to config file")
    parser.add_argument("--debug", action="store_true", help="enable debug mode")
    parser.add_argument("--trainer", default="", help="name of the trainer module under models/")
    # parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
    parser.add_argument(
        "opts",
        help="Modify config options by adding 'KEY VALUE' pairs at the end of the command. "
        "See config references at "
        "https://detectron2.readthedocs.io/modules/config.html#config-references",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser


def init_seed(config):
    """
    Seed all random number generators and make cuDNN deterministic.
    """
    if config.SEED is None:
        seed = random.randrange(0, 10000)
        logging.info("No random seed provided; generated one randomly: {}".format(seed))
        config.SEED = seed
    else:
        logging.info("Using random seed from the config: {}".format(config.SEED))
    torch.manual_seed(config.SEED)
    torch.cuda.manual_seed_all(config.SEED)  # seeds every GPU, including the current device
    random.seed(config.SEED)
    np.random.seed(config.SEED)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    return config


def setup(args):
    """
    Perform the basic common setup at the beginning of a job:
    build the config (defaults, trainer-specific overrides, YAML file, command-line opts),
    create the output directory, configure logging, seed the RNGs, and save the final config.
    """
    cfg = get_cfg_defaults()
    assert args.trainer, "Please provide the '--trainer' argument"
    trainer_lib = importlib.import_module('models.' + args.trainer)
    if 'custom_cfg' in trainer_lib.__all__:
        cfg.merge_from_other_cfg(trainer_lib.custom_cfg)
    if args.config:
        cfg.merge_from_file(args.config)
    if args.opts:
        cfg.merge_from_list(args.opts)
    # cfg.DEBUG = args.debug
    cfg.OUTPUT_DIR = mkdirs(cfg.OUTPUT_DIR)
    output_dir = cfg.OUTPUT_DIR
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s - %(levelname)s - %(filename)s-%(funcName)s-%(lineno)d:%(message)s",
                        datefmt='%a-%d %b %Y %H:%M:%S',
                        handlers=[logging.FileHandler(os.path.join(output_dir, 'train.log'), 'a', 'utf-8'),
                                  logging.StreamHandler()])
    logging.info("Config:\n" + str(cfg))
    logging.info("Command line: python " + ' '.join(sys.argv))
    cfg = init_seed(cfg)
    config_saved_path = os.path.join(output_dir, "train_config.yaml")
    with open(config_saved_path, "w") as f:
        f.write(cfg.dump())
    logging.info("Full config saved to {}".format(config_saved_path))
    cfg.freeze()
    return trainer_lib, cfg


def main(trainer_lib, config):
    # Directories and logging are already set up in setup(); the trainer module is passed in
    # so that different models can plug in their own Trainer implementation.
    trainer = trainer_lib.Trainer(config)
    trainer.set_trainMode(True)
    trainer.set_model()
    trainer.set_dataloader()
    trainer.set_augment()
    trainer.set_iteration()
    trainer.train()


if __name__ == '__main__':
    parser = default_argument_parser()
    args = parser.parse_args()
    trainer_lib, config = setup(args)
    main(trainer_lib, config)
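

For reference, below is a minimal, hypothetical sketch of the trainer-module contract that train.py appears to rely on: setup() imports models.<trainer>, merges an optional custom_cfg when it is exported via __all__, and main() drives the resulting Trainer through set_trainMode, set_model, set_dataloader, set_augment, set_iteration, and train(). The sketch assumes the project's config is a yacs CfgNode (which matches the merge_from_* / dump() / freeze() calls above); everything else here (module name, config keys, the toy model and data) is a placeholder, not the repository's actual implementation.

# models/example_trainer.py -- hypothetical module, selected with `--trainer example_trainer`
import logging

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from yacs.config import CfgNode as CN

__all__ = ["custom_cfg", "Trainer"]

# Optional trainer-specific defaults; setup() merges these into the global config
# before the YAML file and command-line opts are applied.
custom_cfg = CN()
custom_cfg.MODEL = CN()
custom_cfg.MODEL.HIDDEN_DIM = 64


class Trainer:
    def __init__(self, config):
        self.config = config
        self.is_train = False
        self.model = None
        self.dataloader = None
        self.max_iter = 0

    def set_trainMode(self, is_train):
        # main() calls this with True before building anything else.
        self.is_train = is_train

    def set_model(self):
        # Build the network from the (frozen) config; a toy linear model here.
        self.model = nn.Linear(self.config.MODEL.HIDDEN_DIM, 1)

    def set_dataloader(self):
        # Build the training dataloader; a dummy tensor dataset here.
        x = torch.randn(128, self.config.MODEL.HIDDEN_DIM)
        y = torch.randn(128, 1)
        self.dataloader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

    def set_augment(self):
        # Configure data augmentation; a no-op in this sketch.
        pass

    def set_iteration(self):
        # Decide how many iterations to run.
        self.max_iter = 10

    def train(self):
        # A bare-bones optimisation loop.
        optimizer = optim.SGD(self.model.parameters(), lr=0.01)
        loss_fn = nn.MSELoss()
        self.model.train(self.is_train)
        for it, (inputs, targets) in enumerate(self.dataloader):
            if it >= self.max_iter:
                break
            optimizer.zero_grad()
            loss = loss_fn(self.model(inputs), targets)
            loss.backward()
            optimizer.step()
            logging.info("iter %d loss %.4f", it, loss.item())

With a module like this in place, a run would look something like `python train.py --trainer example_trainer --config configs/example.yaml OUTPUT_DIR runs/example`, where the config path and the trailing KEY VALUE override are illustrative.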