-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
65 lines (49 loc) · 1.98 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# -*- coding: utf-8 -*-
"""
Created on Tue May 12 18:44:20 2020
@author: Kartik
"""
import os
import torch
import numpy as np
from tqdm import tqdm
from torchmeta.utils.prototype import get_prototypes, prototypical_loss
import model
import config
import dataloader
from model import get_accuracy
def load_model(path=os.path.join('saved_models', 'protonet_omniglot_5shot_5way.pt')):
    """Build a PrototypicalNetwork from ``config`` and load saved weights into it.

    The network is constructed with ``config.in_channel``,
    ``config.embedding_size`` and ``config.hidden_size``, and its parameters
    are then replaced with the ``state_dict`` stored at ``path``.

    Args:
        path: checkpoint file containing a ``state_dict``. Defaults to
            ``saved_models/protonet_omniglot_5shot_5way.pt``, the path that was
            previously hard-coded.

    Returns:
        The network with weights loaded. It is left on the CPU; callers are
        responsible for moving it to the target device.
    """
    m = model.PrototypicalNetwork(config.in_channel, config.embedding_size, config.hidden_size)
    # map_location="cpu" lets a checkpoint saved from a GPU run be loaded on a
    # CPU-only machine; without it torch.load tries to restore tensors onto
    # the device they were saved from.
    state_dict = torch.load(path, map_location="cpu")
    m.load_state_dict(state_dict)
    return m
def test(device, testset, testloader, model):
    """Evaluate a prototypical network on meta-test batches.

    For each batch drawn from ``testloader``:
      * the support ("train") examples are embedded and reduced to class
        prototypes with ``get_prototypes``;
      * the query ("test") examples are embedded and scored against those
        prototypes with ``get_accuracy``.

    At most ``config.num_batches`` batches are evaluated.

    Args:
        device: device the model and every batch tensor are moved to.
        testset: meta-dataset; only ``testset.num_classes_per_task`` is read.
        testloader: iterable of batches, each a mapping with ``'train'`` and
            ``'test'`` entries holding ``(inputs, targets)`` pairs.
        model: the embedding network. Note: inside this function this
            parameter shadows the module-level ``model`` import.

    Returns:
        List of per-batch accuracies exactly as returned by ``get_accuracy``.
        ``.item()`` is called on them, so they are presumably 0-d tensors.
    """
    model.to(device)
    model.eval()
    acc = []
    # Pure evaluation: disable autograd for the embedding forward passes as
    # well, not only for the accuracy computation, so no graph is built.
    with torch.no_grad(), tqdm(testloader, total=config.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            # Support set -> per-class prototypes.
            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=device)
            train_targets = train_targets.to(device=device)
            train_embeddings = model(train_inputs)
            prototypes = get_prototypes(train_embeddings, train_targets,
                                        testset.num_classes_per_task)

            # Query set -> embeddings scored against the prototypes.
            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device=device)
            test_targets = test_targets.to(device=device)
            test_embeddings = model(test_inputs)

            accuracy = get_accuracy(prototypes, test_embeddings, test_targets)
            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            acc.append(accuracy)

            # batch_idx is 0-based, so stop once num_batches batches are done.
            # (The previous `batch_idx >= num_batches` check evaluated one
            # extra batch and overran the progress bar's total.)
            if batch_idx + 1 >= config.num_batches:
                break
    return acc
# main: load the data and the trained model, evaluate, and report accuracy.
# NOTE(review): trainset/trainloader are never used below; the call is kept
# in case dataloader relies on it (e.g. for downloading data) — confirm
# before removing it.
trainset, trainloader = dataloader.load_meta_trainset()
testset, testloader = dataloader.load_meta_testset()
print("dataset loaded")
m = load_model()
print("Model loaded")
accuracy = test(config.device, testset, testloader, m)
# Report the mean over all evaluated batches. accuracy[-1] alone is only the
# score of the last batch, which is a noisy estimate. float() accepts both
# 0-d tensors and plain numbers.
if accuracy:
    mean_accuracy = sum(float(a) for a in accuracy) / len(accuracy)
    print("Final accuracy: ", mean_accuracy)
else:
    print("Final accuracy: no batches evaluated")