train.py
import argparse
import os

import torch
from apex import amp

# synchronize comes from this repo's utils; get_rank is assumed to live in the
# same helper module (it is used below but was never imported)
from utils.common import synchronize, get_rank
def main():
    parser = argparse.ArgumentParser(description="Distributed training")
    # TODO: add necessary arguments here
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

    num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    args.distributed = num_gpus > 1
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend='nccl',
            init_method='env://',
        )
        synchronize()
    # TODO: define model structure
    """About model
    model = build_model()
    device = torch.device("cuda")
    model.to(device)
    """
    # TODO: define optimizer and lr_scheduler
    """About optimizer and scheduler
    optimizer = make_optimizer()
    scheduler = make_lr_scheduler()
    """
    # TODO: whether to use mixed-precision training
    """mixed-precision training, powered by apex
    use_mixed_precision = True
    # apex opt levels are capital-O strings: 'O0' (pure fp32) up to 'O3'
    amp_opt_level = 'O1' if use_mixed_precision else 'O0'
    model, optimizer = amp.initialize(model, optimizer, opt_level=amp_opt_level)
    """
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            # remove this if BatchNorm buffers should be broadcast
            # from rank 0 (i.e. kept in sync) at each forward pass
            broadcast_buffers=False,
        )
    # only rank 0 writes checkpoints/logs to disk
    save_to_disk = get_rank() == 0
    # TODO: dataset
    """define datasets
    dataset = build_dataset()
    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    """
    # TODO: dataloader
    """
    # define collator: BatchCollator
    # define num_workers
    # two ways: 1. define batch_size (+ sampler); 2. define batch_sampler
    # 1.
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=BatchCollator(),
    )
    # 2.
    data_loader = torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=BatchCollator(),
    )
    """
    # TODO: train and validation
    """
    do_train()
    inference()
    """
if __name__ == "__main__":
    main()