train.py
"""Train base models to later be pruned"""
from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import argparse
import random
import numpy as np
from models import get_model
from utils import *
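# NB: the wildcard import is relied on to provide select_devices,
# get_cifar_loaders, train, and validate, all of which are used below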
from tqdm import tqdm
################################################################## ARGUMENT PARSING
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 Training")
parser.add_argument(
    "--model", default="resnet18", help="resnet9/18/34/50, wrn_40_2/_16_2/_40_1"
)
parser.add_argument("--data_loc", default="/disk/scratch/datasets/cifar", type=str)
parser.add_argument("--checkpoint", default=None, type=str)
parser.add_argument("--n_gpus", default=0, type=int, help="Number of GPUs to use")
### training specific args
parser.add_argument("--epochs", default=200, type=int)
parser.add_argument("--lr", default=0.1)
parser.add_argument(
"--lr_decay_ratio", default=0.2, type=float, help="learning rate decay"
)
parser.add_argument("--weight_decay", default=0.0005, type=float)
### reproducibility
parser.add_argument("--seed", default=1, type=int)
args = parser.parse_args()
print(args.data_loc)
################################################################## REPRODUCIBILITY
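# force cuDNN into deterministic mode (and disable autotuning) so repeated
# runs with the same seed match, at some cost in speed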
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
################################################################## MODEL LOADING
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    select_devices(num_gpus_to_use=args.n_gpus)

model = get_model(args.model)
if torch.cuda.is_available():
    model = model.cuda()
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
model.to(device)
if args.checkpoint is None:
    args.checkpoint = args.model
################################################################## TRAINING HYPERPARAMETERS
trainloader, testloader = get_cifar_loaders(args.data_loc)
optimizer = optim.SGD(
    # exclude any pruning-mask parameters from the weight update
    [w for name, w in model.named_parameters() if "mask" not in name],
    lr=args.lr,
    momentum=0.9,
    weight_decay=args.weight_decay,
)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=1e-10)
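# cosine annealing decays the lr from args.lr down to eta_min over args.epochs;
# note that args.lr_decay_ratio is parsed but unused with this scheduler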
criterion = nn.CrossEntropyLoss()
################################################################## ACTUAL TRAINING
error_history = []
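# note: validate() is given a checkpoint name, and at epoch 2 that name gets an
# "_init" suffix, presumably saving an early-training snapshot for pruning
# methods that rewind weights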
for epoch in tqdm(range(args.epochs)):
    train(model, trainloader, criterion, optimizer)
    validate(
        model,
        epoch,
        testloader,
        criterion,
        checkpoint=args.checkpoint if epoch != 2 else args.checkpoint + "_init",
        seed=args.seed,
    )
    scheduler.step()