# Worker.py
import torch
import torch.nn as nn
import utils
from model import Network
from torch.autograd import Variable
import torchvision
import torch.backends.cudnn as cudnn
import numpy as np
import json
from weight_hist import benford_r2_model
from time import time
class Worker:
    """Container for one sampled architecture and its evaluation results.

    Attributes:
        actions_p: stored unchanged from the constructor argument.
        actions_log_p: stored unchanged from the constructor argument.
        actions_index: stored unchanged from the constructor argument.
        genotype: result of ``utils.parse_actions_index(actions_index)``.
        args: run configuration object, stored unchanged.
        device: device identifier, stored unchanged.
        params_size: ``None`` until an evaluation fills it in.
        acc: ``None`` until an evaluation fills it in.
        memory_stack: list that starts empty; the training loop in this
            file appends validation accuracies to it for early stopping.
    """

    def __init__(self, actions_p, actions_log_p, actions_index, args, device):
        # Raw sampling outputs are kept exactly as received.
        self.actions_p, self.actions_log_p, self.actions_index = (
            actions_p,
            actions_log_p,
            actions_index,
        )
        # Architecture description derived from the sampled indices.
        self.genotype = utils.parse_actions_index(actions_index)
        self.args, self.device = args, device
        # Evaluation results are unknown until the model has been trained.
        self.params_size = None
        self.acc = None
        self.memory_stack = list()
def get_acc(worker):
    """Train the worker's architecture on CIFAR-10 and record its accuracy.

    Builds ``Network(worker.genotype)``, trains it with SGD on the first
    ``train_portion`` of the CIFAR-10 training set and validates on the
    remaining part after every epoch. Training stops early once the latest
    validation accuracy neither improves on the previous epoch nor stays
    above the value recorded three epochs before it.

    Args:
        worker: object providing ``args`` (reads ``seed``, ``data``,
            ``train_portion``, ``batch_size``, ``model_lr``,
            ``model_momentum``, ``model_weight_decay``, ``model_epochs``),
            ``device``, ``genotype`` and a ``memory_stack`` list.

    Side effects:
        * sets ``worker.params_size`` and ``worker.acc``;
        * appends one validation accuracy per epoch to
          ``worker.memory_stack``;
        * appends a JSON object with keys ``"Accuracy"`` and ``"MLH"`` to a
          file in the current directory named after the integer Unix time.
    """
    torch.manual_seed(worker.args.seed)
    np.random.seed(worker.args.seed)

    if torch.cuda.is_available():
        device = torch.device(worker.device)
        cudnn.benchmark = True
        # The flag is ``enabled``; the former ``cudnn.enable`` only created
        # an unused attribute on the module.
        cudnn.enabled = True
        torch.cuda.manual_seed(worker.args.seed)
    else:
        device = torch.device('cpu')

    train_transform, valid_transform = utils._data_transforms_cifar10(worker.args)
    train_data = torchvision.datasets.CIFAR10(root=worker.args.data, train=True,
                                              transform=train_transform,
                                              download=True)

    # Both loaders draw from the training set: the first ``split`` indices
    # are used for training, the rest for validation.
    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(worker.args.train_portion * num_train))

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=worker.args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=False, num_workers=2)

    valid_queue = torch.utils.data.DataLoader(
        train_data, batch_size=worker.args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
        pin_memory=False, num_workers=2)

    criterion = nn.CrossEntropyLoss()
    model = Network(worker.genotype).to(device)
    worker.params_size = utils.count_params(model)

    optimizer = torch.optim.SGD(model.parameters(),
                                worker.args.model_lr,
                                momentum=worker.args.model_momentum,
                                weight_decay=worker.args.model_weight_decay)

    # The accuracy history lives on the worker; the function has no ``self``,
    # so the previous ``self.memory_stack`` references raised NameError.
    history = worker.memory_stack

    # NOTE(review): the hard-coded 500 extra epochs presumably let early
    # stopping, rather than the epoch budget, end training - confirm.
    for model_epoch in range(worker.args.model_epochs+500):
        train_loss, train_acc = train(model, train_queue, criterion, optimizer, device)
        print('train loss {:.4f} acc {:.4f}'.format(train_loss, train_acc))
        valid_loss, valid_acc = infer(model, valid_queue, criterion, device)
        print('valid loss {:.4f} acc {:.4f}'.format(valid_loss, valid_acc))

        # Early-stopping callback: needs at least four recorded epochs.
        history.append(valid_acc)
        if len(history) > 3:
            if history[-1] > history[-2]:
                # Still improving epoch over epoch: keep training.
                continue
            elif history[-1] < history[-4]:
                print("Early Stopping Callback triggered, stopping training...")
                break

    # Report the accuracy recorded three epochs before the final one.
    # NOTE(review): presumably the value preceding the decline that
    # triggered early stopping - confirm this is the intended metric.
    worker.acc = history[-4]

    # Dump validation accuracy and MLH.
    mlh = benford_r2_model(model)
    dump_dict = {"Accuracy": worker.acc, "MLH": mlh}
    filename = int(time())
    # Context manager closes the file even if serialisation fails.
    with open(f"{filename}", "a") as out:
        json.dump(dump_dict, out)
def train(model, train_queue, criterion, optimizer, device):
    """Run one optimisation pass over ``train_queue``.

    Args:
        model: network to train; switched to training mode.
        train_queue: sized iterable yielding ``(inputs, targets)`` batches.
        criterion: callable returning the loss for ``(logits, targets)``.
        optimizer: optimiser stepped once per batch.
        device: device each batch is moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, mean_acc)``: the per-batch loss and the first
        value returned by ``utils.accuracy``, each averaged over the number
        of batches.
    """
    n_batches = len(train_queue)
    loss_total = 0
    acc_total = 0

    model.train()
    for images, labels in train_queue:
        images = Variable(images, requires_grad=False).to(device)
        labels = Variable(labels, requires_grad=False).to(device)

        # Standard step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(images)
        batch_loss = criterion(outputs, labels)
        batch_loss.backward()
        optimizer.step()

        batch_acc = utils.accuracy(outputs.data, labels.data)[0]
        loss_total += float(batch_loss)
        acc_total += float(batch_acc)

    return loss_total / n_batches, acc_total / n_batches
def infer(model, valid_queue, criterion, device):
    """Evaluate ``model`` over ``valid_queue`` without tracking gradients.

    Args:
        model: network to evaluate; switched to evaluation mode.
        valid_queue: sized iterable yielding ``(inputs, targets)`` batches.
        criterion: callable returning the loss for ``(logits, targets)``.
        device: device each batch is moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, mean_acc)``: the per-batch loss and the first
        value returned by ``utils.accuracy``, each averaged over the number
        of batches.
    """
    n_batches = len(valid_queue)
    loss_total = 0
    acc_total = 0

    model.eval()
    for images, labels in valid_queue:
        # Gradient tracking is disabled for each batch's forward pass.
        with torch.no_grad():
            images = Variable(images).to(device)
            labels = Variable(labels).to(device)
            outputs = model(images)
            batch_loss = criterion(outputs, labels)
            batch_acc = utils.accuracy(outputs.data, labels.data)[0]
        loss_total += float(batch_loss)
        acc_total += float(batch_acc)

    return loss_total / n_batches, acc_total / n_batches