forked from basiralab/RepNet
gunet_trainer.py
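"""Per-fold training / evaluation loop for the graph classification network.

Trains with Adam, evaluates on the test split after every epoch, and
appends the best test accuracy for the current fold to ``args.acc_file``.
"""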
import torch
import torch.optim as optim
from tqdm import tqdm

from utils.dataset import GraphData


class Trainer:
    def __init__(self, args, net, G_data):
        self.args = args
        self.net = net
        self.feat_dim = G_data.feat_dim
        self.fold_idx = G_data.fold_idx
        self.init(args, G_data.train_gs, G_data.test_gs)
        if torch.cuda.is_available():
            self.net.cuda()

    def init(self, args, train_gs, test_gs):
        print('#train: %d, #test: %d' % (len(train_gs), len(test_gs)))
        train_data = GraphData(train_gs, self.feat_dim)
        test_data = GraphData(test_gs, self.feat_dim)
        # The boolean flag presumably toggles shuffling: train only.
        self.train_d = train_data.loader(self.args.batch, True)
        self.test_d = test_data.loader(self.args.batch, False)
        self.optimizer = optim.Adam(
            self.net.parameters(), lr=self.args.lr, amsgrad=True,
            weight_decay=0.0008)

    def to_cuda(self, gs):
        # Move a graph (or a list of graphs) to the GPU when one is available.
        if torch.cuda.is_available():
            if isinstance(gs, list):
                return [g.cuda() for g in gs]
            return gs.cuda()
        return gs

    def run_epoch(self, epoch, data, model, optimizer):
        losses, accs, n_samples = [], [], 0
        for batch in tqdm(data, desc=str(epoch), unit='b'):
            cur_len, gs, hs, ys = batch
            gs, hs, ys = map(self.to_cuda, [gs, hs, ys])
            loss, acc = model(gs, hs, ys)
            # Weight per-batch stats by batch size; detach the stored copy so
            # every batch's autograd graph is not kept alive for a whole epoch.
            losses.append(loss.detach() * cur_len)
            accs.append(acc * cur_len)
            n_samples += cur_len
            if optimizer is not None:  # None when evaluating
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        # Sample-weighted averages over the epoch.
        avg_loss, avg_acc = sum(losses) / n_samples, sum(accs) / n_samples
        return avg_loss.item(), avg_acc.item()

    def train(self):
        max_acc = 0.0
        train_str = 'Train epoch %d: loss %.5f acc %.5f'
        test_str = 'Test epoch %d: loss %.5f acc %.5f max %.5f'
        line_str = '%d:\t%.5f\n'
        for e_id in range(self.args.num_epochs):
            self.net.train()
            loss, acc = self.run_epoch(
                e_id, self.train_d, self.net, self.optimizer)
            print(train_str % (e_id, loss, acc))
            # Evaluate on the test split after every training epoch.
            with torch.no_grad():
                self.net.eval()
                loss, acc = self.run_epoch(e_id, self.test_d, self.net, None)
            max_acc = max(max_acc, acc)
            print(test_str % (e_id, loss, acc, max_acc))
        # Record the best test accuracy reached in this fold.
        with open(self.args.acc_file, 'a+') as f:
            f.write(line_str % (self.fold_idx, max_acc))
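

# Example usage -- a minimal sketch, not part of the original file. The
# attribute names on `args` are the ones this trainer actually reads; how
# `net` and `G_data` are constructed elsewhere in the repo is an assumption.
#
#   from types import SimpleNamespace
#
#   args = SimpleNamespace(batch=32, lr=0.001,
#                          num_epochs=200, acc_file='acc_results.txt')
#   # `net` must return (loss, acc) when called as net(gs, hs, ys);
#   # `G_data` must expose feat_dim, fold_idx, train_gs and test_gs.
#   trainer = Trainer(args, net, G_data)
#   trainer.train()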