transformer_preds.py
import torch
import numpy as np
from ARCDataset import ARCTest
import argparse
import os
import torch.nn.functional as F
from matplotlib import colors
from glob import glob
from transformer_model import TransformerModel
from utils import seed_everything, plot_figure
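# Per-task evaluation script: load the highest-numbered transformer checkpoint
# from ./model_weights, fine-tune it on each ARC test task's demonstration
# pairs (inner-loop adaptation), and plot predictions for the query grids.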
seed_everything()
innerstepsize = 1e-2  # learning rate of the inner-loop (per-task) optimizer
innerepochs = 50  # number of fine-tuning steps per task in the inner loop
cmap = colors.ListedColormap(
    ['#000000', '#0074D9', '#FF4136', '#2ECC40', '#FFDC00',
     '#AAAAAA', '#F012BE', '#FF851B', '#7FDBFF', '#870C25', '#FFFFFF'])
norm = colors.Normalize(vmin=0, vmax=10)
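# The first 10 colours above are the standard ARC palette (symbols 0-9); the
# 11th (#FFFFFF) is presumably an extra symbol used for padding. Neither cmap
# nor norm is referenced below, so they appear to be leftovers from a plotting
# routine that now lives in plot_figure.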

def main():
    weights_dir = './model_weights'
    os.makedirs('./model_preds', exist_ok=True)
    print(args)

    ntokens = 11  # size of the vocabulary
    emsize = 32  # embedding dimension
    nhid = 64  # dimension of the feedforward network in nn.TransformerEncoder
    nlayers = 2  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
    nhead = 4  # number of heads in the multi-head attention models
    dropout = 0.5  # dropout value
    device = torch.device('cuda')
    model = TransformerModel(ntokens, emsize, nhead,
                             nhid, nlayers, dropout).to(device)

    # Sort checkpoints by the numeric suffix in their filename and use the last one.
    def cond(x):
        return float(x.split('/')[-1].split('_')[-1][:-4])
    all_model_fn = sorted(glob(f'{weights_dir}/*.pth'), key=cond)[-1]
    print('Using model weights from', all_model_fn)

    # ARC evaluation tasks; root should point at the local ARC data directory.
    arc_dataset = ARCTest(
        root='/home/sid/Desktop/arc/data/', imgsz=args.imgsz)

    all_train_acc = []
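    # For each task: reload the saved weights so every task starts from the same
    # initialization, fine-tune for `innerepochs` steps on the task's (x, y)
    # demonstration pairs, then predict the held-out query grid q.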
    for step, ((x, y), q) in enumerate(zip(arc_dataset, arc_dataset.query_x_batch)):
        state = torch.load(all_model_fn)
        model.load_state_dict(state)
        optimizer = torch.optim.AdamW(model.parameters(), lr=innerstepsize)

        x = x.to(device).reshape(-1, args.imgsz*args.imgsz).long()
        y = y.to(device)

        train_losses = []
        train_acc = []
        model.train()
        for _ in range(innerepochs):
            optimizer.zero_grad()
            outputs = model(x).reshape(-1, args.num_class)
            loss = F.cross_entropy(outputs, y)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
            acc = (outputs.argmax(1) == y).float().mean().item()
            train_acc.append(acc)
        print('\ttraining loss:', np.mean(train_losses),
              '\ttraining acc:', np.mean(train_acc))
        all_train_acc.append(np.mean(train_acc))
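        # Predict the query grid(s) with the adapted weights; plot_figure is
        # assumed to save the visualisation under ./model_preds (created above).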
        model.eval()
        with torch.no_grad():
            q = torch.tensor(
                q.reshape(-1, args.imgsz*args.imgsz)).to(device).long()
            # Softmax and argmax over the class dimension (last axis of the
            # model output), then reshape the predictions back into grids.
            outputs = F.softmax(model(q), dim=-1)
            outputs = outputs.argmax(-1).reshape(-1, args.imgsz, args.imgsz)
            plot_figure(x, y, q, outputs, im_num=step, img_sz=args.imgsz)

    print('\nmean train acc:', np.mean(all_train_acc),
          'stddev train acc:', np.std(all_train_acc))

if __name__ == '__main__':
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--epoch', type=int,
                           help='epoch number', default=501)
    argparser.add_argument('--num_class', type=int,
                           help='number of classes', default=11)
    argparser.add_argument('--imgsz', type=int,
                           help='image (grid) size', default=15)
    args = argparser.parse_args()
    main()
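
# Example invocation (a sketch; the values shown simply repeat the defaults above):
#   python transformer_preds.py --imgsz 15 --num_class 11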